80 Commits

Author SHA1 Message Date
Jason A. Donenfeld
b41922e5c8 version: bump snapshot 2018-10-01 17:58:31 +02:00
Jason A. Donenfeld
dbb72402f2 Adding missing queueconstants file 2018-10-01 16:11:31 +02:00
Chris Branch
7c971d7ef4 Fix transport message length check
wireguard-go has a bad length check in its transport message handling.
Although it cannot be exploited because of another length check earlier in the
function, this should be fixed regardless.
2018-09-25 05:18:11 +02:00
Jason A. Donenfeld
70bcf9ecb8 Make it easy to restrict queue sizes more 2018-09-25 02:31:02 +02:00
Jason A. Donenfeld
ebc7541953 Fix shutdown races 2018-09-24 01:52:02 +02:00
Jason A. Donenfeld
833597b585 More pooling 2018-09-24 00:37:43 +02:00
Jason A. Donenfeld
cf81a28dd3 Fixup buffer freeing 2018-09-22 05:43:03 +02:00
Jason A. Donenfeld
942abf948a send: more precise padding calculation 2018-09-16 23:42:31 +02:00
Jason A. Donenfeld
47d1140361 device: preallocated buffers scheme
Not useful now but quite possibly later.
2018-09-16 23:10:19 +02:00
Jason A. Donenfeld
39d6e4f2f1 Change queueing drop order and fix memory leaks
If the queues are full, we drop the present packet, which is better for
network traffic flow. Also, we try to fix up the memory leaks with not
putting buffers from our shared pool.
2018-09-16 21:50:58 +02:00
Jason A. Donenfeld
1c02557013 send: use accessor function for buffer pool 2018-09-16 18:49:19 +02:00
Mathias Hall-Andersen
32d2148835 Fixed port overwrite issue on kernels without ipv6
Fixed an issue in CreateBind for Linux:
If ipv6 was not supported the error code would be
correctly identified as EAFNOSUPPORT and ipv4 binding attempted.
However the port would be set to 0,
which results in the subsequent create4 call requesting
a random port rather than the one provided to CreateBind.

This issue was identified by:
Kent Friis <leeloored@gmx.com>
2018-09-16 18:49:19 +02:00
Jason A. Donenfeld
5be541d147 global: fix up copyright headers 2018-09-16 18:49:19 +02:00
Jason A. Donenfeld
063becdc73 uapi: insert peer version placeholder
While we don't want people to ever use old protocols, people will
complain if the API "changes", so explicitly make the unset protocol
mean the latest, and add a dummy mechanism of specifying the protocol on
a per-peer basis, which we hope nobody actually ever uses.
2018-09-02 23:04:47 -06:00
Jason A. Donenfeld
15da869b31 Fix duplicate copyright line 2018-07-30 05:14:17 +02:00
Jason A. Donenfeld
3ad3e83c7a uapi: allow overriding socket directory at compile time 2018-07-24 14:32:35 +02:00
Jason A. Donenfeld
2e13b7b0fb send: better debug message for failed data packet 2018-07-16 16:05:36 +02:00
Jason A. Donenfeld
6b3b1c3b91 version: bump snapshot 2018-06-13 16:22:16 +02:00
Jason A. Donenfeld
6a5d0e2bcd Support IPv6-less kernels 2018-06-12 01:32:46 +02:00
Jason A. Donenfeld
0ba551807f Do not build tun device on ios 2018-06-09 03:31:17 +02:00
Jason A. Donenfeld
99d5aeeb27 Fix duplicated wording 2018-06-02 17:36:35 +02:00
Jason A. Donenfeld
a050431f26 Makefile: export PWD for OpenBSD's ksh(1)
Interestingly, ksh(1) on OpenBSD does not export PWD by default, and it
also has a notion of the "logical cwd" vs the "physical cwd", with the
latter being passed to chdir, but the former being stored in the
non-exported PWD and displayed to the user. This means that if you `cd`
into a directory that's comprised of symlinks, exec'd processes will see
the physical path. Observe:

  # ksh
  # mkdir a
  # ln -s a b
  # cd b
  # pwd
  /root/b
  # ksh -c pwd
  /root/a

The fact of separating physical and logical paths is not too uncommon
for shells (bash does it too), but not exporting PWD is very odd.

Since this is common behavior for many shells, libraries that return the
working directory will do something strange: they `stat(".")` and then
`stat(getenv("PWD"))`, and if these point to the same inode, they roll
with the value of `getenv("PWD")`, or otherwise fallback to asking the
kernel for the cwd.

Since PWD was not exported by ksh(1), Go's dep utility did not understand
it was operating inside of our faked GOPATH and became upset.

This patch works around the whole situation by simply exporting PWD
before executing dep.
2018-06-02 16:36:12 +02:00
Jason A. Donenfeld
0c976003c8 version: bump snapshot 2018-05-31 02:26:07 +02:00
Jason A. Donenfeld
955e89839f Print version number in log 2018-05-30 01:09:18 +02:00
Jason A. Donenfeld
a4cd0216c0 Update deps 2018-05-28 01:39:37 +02:00
Jason A. Donenfeld
1d7845a600 Fix typo in timers 2018-05-27 22:55:15 +02:00
Jason A. Donenfeld
5079298ce2 Disable broadcast mode on *BSD
Keeping it on makes IPv6 problematic and confuses routing daemons.
2018-05-27 22:55:15 +02:00
Jason A. Donenfeld
fc3a7635e5 Disappointing anti-sticky experiment 2018-05-27 22:55:15 +02:00
Jason A. Donenfeld
2496cdd8e6 Fix tests 2018-05-24 19:58:16 +02:00
Jason A. Donenfeld
4365b4583f Trick for being extra sensitive to route changes 2018-05-24 18:21:14 +02:00
Jason A. Donenfeld
bbf320c477 Back to sticky sockets on android 2018-05-24 17:53:00 +02:00
Jason A. Donenfeld
625d59da14 Do not build on Linux 2018-05-24 16:41:42 +02:00
Jason A. Donenfeld
2f2eca8947 Catch EINTR 2018-05-24 15:36:29 +02:00
Jason A. Donenfeld
66f6ca3e4a Remove old makefile artifact 2018-05-24 03:13:46 +02:00
Jason A. Donenfeld
e6657638fc version: bump snapshot 2018-05-24 02:25:51 +02:00
Jason A. Donenfeld
4a9de3218e Add undocumented --version flag 2018-05-24 02:25:36 +02:00
Jason A. Donenfeld
28a167e828 Eye before ee except after see 2018-05-23 19:00:00 +02:00
Jason A. Donenfeld
99c6513d60 No zero sequence numbers 2018-05-23 18:30:55 +02:00
Jason A. Donenfeld
8a92a9109a Don't cause a new fake gopath to call dep 2018-05-23 17:31:06 +02:00
Jason A. Donenfeld
0b647d1ca7 Infoleak ifnames and be more permissive
Listing interfaces is already permitted by the OS, so we allow this info
leak too.
2018-05-23 15:38:24 +02:00
Jason A. Donenfeld
588b9f01ae Adopt GOPATH
GOPATH is annoying, but the Go community pushing me to adopt it is even
more annoying.
2018-05-23 05:18:13 +02:00
Jason A. Donenfeld
f70bd1fab3 Remove more windows cruft 2018-05-23 04:46:09 +02:00
Jason A. Donenfeld
40d5ff0c70 Cleanup 2018-05-23 03:58:27 +02:00
Jason A. Donenfeld
5a2228a5c9 Move replay into subpackage 2018-05-23 03:58:27 +02:00
Jason A. Donenfeld
0a63188afa Move tun to subpackage 2018-05-23 03:58:27 +02:00
Jason A. Donenfeld
65a74f3175 Avoid sticky sockets on Android
The android policy routing system does insane things.
2018-05-22 23:22:23 +02:00
Jason A. Donenfeld
b4cef2524f Fix integer conversions 2018-05-22 18:35:52 +02:00
Jason A. Donenfeld
7038de95e1 Bump dependencies for OpenBSD 2018-05-22 17:58:34 +02:00
Jason A. Donenfeld
82d12e85bb Fix markdown 2018-05-22 16:47:15 +02:00
Jason A. Donenfeld
d6b694e161 Add OpenBSD tun driver support 2018-05-22 16:21:05 +02:00
Jason A. Donenfeld
794e494802 Fix code duplication 2018-05-22 14:59:29 +02:00
Jason A. Donenfeld
dd663a7ba4 Notes on FreeBSD limitations 2018-05-22 01:30:16 +02:00
Jason A. Donenfeld
8462c08cf2 Just in case darwin changes, we also shutdown 2018-05-22 01:27:29 +02:00
Jason A. Donenfeld
b8c9e13c6e Call shutdown on route socket on freebsd 2018-05-22 01:26:47 +02:00
Filippo Valsorda
bc05eb1c3c Minor main.go signal fixes
* Buffer the signal channel as it's non-blocking on the sender side
* Notify on SIGTERM instead of the uncatchable SIGKILL

License: MIT
Signed-off-by: Filippo Valsorda <valsorda@google.com>
2018-05-21 20:22:12 +02:00
Filippo Valsorda
7a527f7c89 Fix Sscanf use in tun_darwin
License: MIT
Signed-off-by: Filippo Valsorda <valsorda@google.com>
2018-05-21 20:21:31 +02:00
Filippo Valsorda
84f52ce0d6 Make successful tests silent
License: MIT
Signed-off-by: Filippo Valsorda <valsorda@google.com>
2018-05-21 20:21:00 +02:00
Filippo Valsorda
7bdc5eb54e Properly close DummyTUN to avoid deadlock in TestNoiseHandshake
License: MIT
Signed-off-by: Filippo Valsorda <valsorda@google.com>
2018-05-21 20:20:13 +02:00
Jason A. Donenfeld
1c666576d5 User cookie is closer to fwmark than setfib 2018-05-21 20:13:39 +02:00
Jason A. Donenfeld
2ae22ac65d Remove broken windows cruft 2018-05-21 19:00:58 +02:00
Jason A. Donenfeld
ff3f2455e5 Rework freebsd support 2018-05-21 18:48:48 +02:00
Brady OBrien
b962d7d791 Add FreeBSD support
Signed-off-by: Brady OBrien <brady.obrien128@gmail.com>
2018-05-21 17:31:22 +02:00
Jason A. Donenfeld
837a12c841 Close events channel when no status listener 2018-05-21 14:16:46 +02:00
Jason A. Donenfeld
7472930d4e Straighten out UAPI logging 2018-05-21 03:38:50 +02:00
Jason A. Donenfeld
6307bfcdf4 Close hack listener before closing channel 2018-05-21 03:31:46 +02:00
Jason A. Donenfeld
e28d70f5b2 ratelimiter: do not run GC with nothing to do 2018-05-21 03:20:18 +02:00
Jason A. Donenfeld
84c5357cf3 Reasonable punctuation given the spacing 2018-05-21 02:50:39 +02:00
Jason A. Donenfeld
acb5481246 Fix data races in timers 2018-05-20 06:50:07 +02:00
Jason A. Donenfeld
18f43705ec Fix race with closing event channel
There's still a tiny race on Linux, since the tun channel is written to
from two places.
2018-05-20 06:38:39 +02:00
Jason A. Donenfeld
058cedcf66 Style 2018-05-20 06:29:46 +02:00
Jason A. Donenfeld
c5fa3de24c Remove unused mtu variable 2018-05-20 06:29:21 +02:00
Jason A. Donenfeld
1068d6b92b Give bind its own wait group
In a waitgroup, all waits must come after all adds
2018-05-20 06:29:21 +02:00
Jason A. Donenfeld
5e924e5407 Avoid deadlock when the mutex isn't required, since these are atomics
Maybe this fixes the "double lock issue" in
f73d2fb2d96bc3fbc8bc4cce452e3c19689de01e?
2018-05-20 06:29:21 +02:00
Jason A. Donenfeld
b290cf05e3 Use proper status listener on Darwin 2018-05-20 06:29:21 +02:00
Jason A. Donenfeld
b95a4c61a5 Reduce the hack listener to once a second 2018-05-20 04:03:11 +02:00
Jason A. Donenfeld
a5b3340e5b Fix race in netlink peer correlator 2018-05-20 03:37:42 +02:00
Jason A. Donenfeld
7c21a3de0a Fix race in lock pending 2018-05-20 03:31:27 +02:00
Jason A. Donenfeld
0a68c1ab17 Fix race in stats 2018-05-20 03:26:46 +02:00
Jason A. Donenfeld
e04f9543c0 Fix race in packetInNonceQueueIsAwaitingKey 2018-05-20 03:24:14 +02:00
Jason A. Donenfeld
fa003b6933 Discourage building for Linux 2018-05-20 03:19:03 +02:00
64 changed files with 1939 additions and 1188 deletions

2
.gitignore vendored
View File

@@ -1,2 +1,4 @@
wireguard-go wireguard-go
vendor vendor
.gopath
ireallywantobuildon_linux.go

34
Gopkg.lock generated
View File

@@ -1,16 +1,42 @@
# This was generated by ./generate-vendor.sh # This file is autogenerated, do not edit; changes may be undone by the next 'dep ensure'.
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "golang.org/x/crypto" name = "golang.org/x/crypto"
revision = "1a580b3eff7814fc9b40602fd35256c63b50f491" packages = [
"blake2s",
"chacha20poly1305",
"curve25519",
"internal/chacha20",
"poly1305"
]
revision = "ab813273cd59e1333f7ae7bff5d027d4aadf528c"
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "golang.org/x/net" name = "golang.org/x/net"
revision = "2491c5de3490fced2f6cff376127c667efeed857" packages = [
"bpf",
"internal/iana",
"internal/socket",
"ipv4",
"ipv6"
]
revision = "dfa909b99c79129e1100513e5cd36307665e5723"
[[projects]] [[projects]]
branch = "master" branch = "master"
name = "golang.org/x/sys" name = "golang.org/x/sys"
revision = "7c87d13f8e835d2fb3a70a2912c811ed0c1d241b" packages = [
"cpu",
"unix"
]
revision = "c11f84a56e43e20a78cee75a7c034031ecf57d1f"
[solve-meta]
analyzer-name = "dep"
analyzer-version = 1
inputs-digest = "d85ae9d2b4afafc3d7535505c46368cbbbec350cf876616302c1bcf44f6ec103"
solver-name = "gps-cdcl"
solver-version = 1

View File

@@ -1,4 +1,3 @@
# This was generated by ./generate-vendor.sh
[[constraint]] [[constraint]]
branch = "master" branch = "master"
name = "golang.org/x/crypto" name = "golang.org/x/crypto"
@@ -11,3 +10,6 @@
branch = "master" branch = "master"
name = "golang.org/x/sys" name = "golang.org/x/sys"
[prune]
go-tests = true
unused-packages = true

View File

@@ -2,15 +2,50 @@ PREFIX ?= /usr
DESTDIR ?= DESTDIR ?=
BINDIR ?= $(PREFIX)/bin BINDIR ?= $(PREFIX)/bin
ifeq ($(shell go env GOOS),linux)
ifeq ($(wildcard .git),)
$(error Do not build this for Linux. Instead use the Linux kernel module. See wireguard.com/install/ for more info.)
else
$(shell printf 'package main\nconst UseTheKernelModuleInstead = 0xdeadbabe\n' > ireallywantobuildon_linux.go)
endif
endif
all: wireguard-go all: wireguard-go
wireguard-go: $(wildcard *.go) $(wildcard */*.go) export GOPATH := $(CURDIR)/.gopath
go build -v -o $@ export PATH := $(PATH):$(CURDIR)/.gopath/bin
GO_IMPORT_PATH := git.zx2c4.com/wireguard-go
version.go:
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
tag="$$(git describe --dirty 2>/dev/null)" && \
ver="$$(printf 'package main\nconst WireGuardGoVersion = "%s"\n' "$$tag")" && \
[ "$$(cat $@ 2>/dev/null)" != "$$ver" ] && \
echo "$$ver" > $@ && \
git update-index --assume-unchanged $@ || true
.gopath/.created:
rm -rf .gopath
mkdir -p $(dir .gopath/src/$(GO_IMPORT_PATH))
ln -s ../../.. .gopath/src/$(GO_IMPORT_PATH)
touch $@
vendor/.created: Gopkg.toml Gopkg.lock | .gopath/.created
command -v dep >/dev/null || go get -v github.com/golang/dep/cmd/dep
export PWD; cd .gopath/src/$(GO_IMPORT_PATH) && dep ensure -vendor-only -v
touch $@
wireguard-go: $(wildcard *.go) $(wildcard */*.go) .gopath/.created vendor/.created version.go
go build -v $(GO_IMPORT_PATH)
install: wireguard-go install: wireguard-go
@install -v -d "$(DESTDIR)$(BINDIR)" && install -m 0755 -v wireguard-go "$(DESTDIR)$(BINDIR)/wireguard-go" @install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 wireguard-go "$(DESTDIR)$(BINDIR)/wireguard-go"
clean: clean:
rm -f wireguard-go rm -f wireguard-go
.PHONY: clean install update-dep: | .gopath/.created
command -v dep >/dev/null || go get -v github.com/golang/dep/cmd/dep
cd .gopath/src/$(GO_IMPORT_PATH) && dep ensure -update -v
.PHONY: clean install update-dep version.go

View File

@@ -32,7 +32,7 @@ This will run on Linux; however **YOU SHOULD NOT RUN THIS ON LINUX**. Instead us
### macOS ### macOS
This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_DARWIN_UTUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable. This runs on macOS using the utun driver. It does not yet support sticky sockets, and won't support fwmarks because of Darwin limitations. Since the utun driver cannot have arbitrary interface names, you must either use `utun[0-9]+` for an explicit interface name or `utun` to have the kernel select one for you. If you choose `utun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
### Windows ### Windows
@@ -40,17 +40,19 @@ It is currently a work in progress to strip out the beginnings of an experiment
### FreeBSD ### FreeBSD
Work in progress, but nothing yet to share. This will run on FreeBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_USER_COOKIE`.
### OpenBSD
This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapped to `SO_RTABLE`. Since the tun driver cannot have arbitrary interface names, you must either use `tun[0-9]+` for an explicit interface name or `tun` to have the program select one for you. If you choose `tun` as the interface name, and the environment variable `WG_TUN_NAME_FILE` is defined, then the actual name of the interface chosen by the kernel is written to the file specified by that variable.
## Building ## Building
You can satisfy dependencies with either `go get -d -v` or `dep ensure -vendor-only`. Then run `make`. As this is a Go project, a `GOPATH` is required. For example, wireguard-go can be built with: This requires an installation of [go](https://golang.org) and of [dep](https://github.com/golang/dep). If dep is not installed, it will be downloaded and built as part of the build process.
``` ```
$ git clone https://git.zx2c4.com/wireguard-go $ git clone https://git.zx2c4.com/wireguard-go
$ cd wireguard-go $ cd wireguard-go
$ export GOPATH="$PWD/gopath"
$ go get -d -v
$ make $ make
``` ```
@@ -74,9 +76,9 @@ $ make
are otherwise in compliance with the GPLv2 for each covered work you convey are otherwise in compliance with the GPLv2 for each covered work you convey
(including without limitation making the Corresponding Source available in (including without limitation making the Corresponding Source available in
compliance with Section 3 of the GPLv2), you are granted the additional compliance with Section 3 of the GPLv2), you are granted the additional
the additional permission to convey through the Apple App Store permission to convey through the Apple App Store non-source executable
non-source executable versions of the Program as incorporated into each versions of the Program as incorporated into each applicable covered work
applicable covered work as Executable Versions only under the Mozilla as Executable Versions only under the Mozilla Public License version 2.0
Public License version 2.0 (https://www.mozilla.org/en-US/MPL/2.0/). (https://www.mozilla.org/en-US/MPL/2.0/).

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main

View File

@@ -1,6 +0,0 @@
@echo off
REM builds wireguard for windows
go get
go build -o wireguard-go.exe

11
conn.go
View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
@@ -66,8 +65,6 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
return addr, err return addr, err
} }
/* Must hold device and net lock
*/
func unsafeCloseBind(device *Device) error { func unsafeCloseBind(device *Device) error {
var err error var err error
netc := &device.net netc := &device.net
@@ -75,6 +72,7 @@ func unsafeCloseBind(device *Device) error {
err = netc.bind.Close() err = netc.bind.Close()
netc.bind = nil netc.bind = nil
} }
netc.stopping.Wait()
return err return err
} }
@@ -162,10 +160,11 @@ func (device *Device) BindUpdate() error {
// start receiving routines // start receiving routines
device.state.starting.Add(ConnRoutineNumber) device.net.starting.Add(ConnRoutineNumber)
device.state.stopping.Add(ConnRoutineNumber) device.net.stopping.Add(ConnRoutineNumber)
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
device.net.starting.Wait()
device.log.Debug.Println("UDP bind has been updated") device.log.Debug.Println("UDP bind has been updated")
} }

View File

@@ -1,15 +1,18 @@
// +build !linux // +build !linux android
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
import ( import (
"golang.org/x/sys/unix"
"net" "net"
"os"
"runtime"
"syscall"
) )
/* This code is meant to be a temporary solution /* This code is meant to be a temporary solution
@@ -85,6 +88,18 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
return conn, uaddr.Port, nil return conn, uaddr.Port, nil
} }
func extractErrno(err error) error {
opErr, ok := err.(*net.OpError)
if !ok {
return nil
}
syscallErr, ok := opErr.Err.(*os.SyscallError)
if !ok {
return nil
}
return syscallErr.Err
}
func CreateBind(uport uint16, device *Device) (Bind, uint16, error) { func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
var err error var err error
var bind NativeBind var bind NativeBind
@@ -92,13 +107,15 @@ func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
port := int(uport) port := int(uport)
bind.ipv4, port, err = listenNet("udp4", port) bind.ipv4, port, err = listenNet("udp4", port)
if err != nil { if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
return nil, 0, err return nil, 0, err
} }
bind.ipv6, port, err = listenNet("udp6", port) bind.ipv6, port, err = listenNet("udp6", port)
if err != nil { if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
return nil, 0, err
bind.ipv4.Close() bind.ipv4.Close()
bind.ipv4 = nil
return nil, 0, err return nil, 0, err
} }
@@ -106,8 +123,13 @@ func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
} }
func (bind *NativeBind) Close() error { func (bind *NativeBind) Close() error {
err1 := bind.ipv4.Close() var err1, err2 error
err2 := bind.ipv6.Close() if bind.ipv4 != nil {
err1 = bind.ipv4.Close()
}
if bind.ipv6 != nil {
err2 = bind.ipv6.Close()
}
if err1 != nil { if err1 != nil {
return err1 return err1
} }
@@ -115,6 +137,9 @@ func (bind *NativeBind) Close() error {
} }
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
if bind.ipv4 == nil {
return 0, nil, syscall.EAFNOSUPPORT
}
n, endpoint, err := bind.ipv4.ReadFromUDP(buff) n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
if endpoint != nil { if endpoint != nil {
endpoint.IP = endpoint.IP.To4() endpoint.IP = endpoint.IP.To4()
@@ -123,6 +148,9 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
} }
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
if bind.ipv6 == nil {
return 0, nil, syscall.EAFNOSUPPORT
}
n, endpoint, err := bind.ipv6.ReadFromUDP(buff) n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
return n, (*NativeEndpoint)(endpoint), err return n, (*NativeEndpoint)(endpoint), err
} }
@@ -131,13 +159,59 @@ func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
var err error var err error
nend := endpoint.(*NativeEndpoint) nend := endpoint.(*NativeEndpoint)
if nend.IP.To4() != nil { if nend.IP.To4() != nil {
if bind.ipv4 == nil {
return syscall.EAFNOSUPPORT
}
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
} else { } else {
if bind.ipv6 == nil {
return syscall.EAFNOSUPPORT
}
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
} }
return err return err
} }
func (bind *NativeBind) SetMark(_ uint32) error { var fwmarkIoctl int
func init() {
switch runtime.GOOS {
case "linux", "android":
fwmarkIoctl = 36 /* unix.SO_MARK */
case "freebsd":
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
case "openbsd":
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
}
}
func (bind *NativeBind) SetMark(mark uint32) error {
if fwmarkIoctl == 0 {
return nil
}
if bind.ipv4 != nil {
fd, err := bind.ipv4.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err != nil {
return err
}
}
if bind.ipv6 != nil {
fd, err := bind.ipv6.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err != nil {
return err
}
}
return nil return nil
} }

View File

@@ -1,7 +1,8 @@
// +build !android
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
* *
* This implements userspace semantics of "sticky sockets", modeled after * This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port * WireGuard's kernelspace implementation. This is more or less a straight port
@@ -16,14 +17,20 @@
package main package main
import ( import (
"./rwcancel"
"errors" "errors"
"git.zx2c4.com/wireguard-go/rwcancel"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"net" "net"
"strconv" "strconv"
"sync"
"syscall"
"unsafe" "unsafe"
) )
const (
FD_ERR = -1
)
type IPv4Source struct { type IPv4Source struct {
src [4]byte src [4]byte
ifindex int32 ifindex int32
@@ -123,6 +130,7 @@ func createNetlinkRouteSocket() (int, error) {
func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) { func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
var err error var err error
var bind NativeBind var bind NativeBind
var newPort uint16
bind.netlinkSock, err = createNetlinkRouteSocket() bind.netlinkSock, err = createNetlinkRouteSocket()
if err != nil { if err != nil {
@@ -136,41 +144,63 @@ func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
go bind.routineRouteListener(device) go bind.routineRouteListener(device)
bind.sock6, port, err = create6(port) // attempt ipv6 bind, update port if succesful
bind.sock6, newPort, err = create6(port)
if err != nil { if err != nil {
bind.netlinkCancel.Cancel() if err != syscall.EAFNOSUPPORT {
return nil, port, err bind.netlinkCancel.Cancel()
return nil, 0, err
}
} else {
port = newPort
} }
bind.sock4, port, err = create4(port) // attempt ipv4 bind, update port if succesful
bind.sock4, newPort, err = create4(port)
if err != nil { if err != nil {
bind.netlinkCancel.Cancel() if err != syscall.EAFNOSUPPORT {
unix.Close(bind.sock6) bind.netlinkCancel.Cancel()
unix.Close(bind.sock6)
return nil, 0, err
}
} else {
port = newPort
} }
return &bind, port, err
if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR {
return nil, 0, errors.New("ipv4 and ipv6 not supported")
}
return &bind, port, nil
} }
func (bind *NativeBind) SetMark(value uint32) error { func (bind *NativeBind) SetMark(value uint32) error {
err := unix.SetsockoptInt( if bind.sock6 != -1 {
bind.sock6, err := unix.SetsockoptInt(
unix.SOL_SOCKET, bind.sock6,
unix.SO_MARK, unix.SOL_SOCKET,
int(value), unix.SO_MARK,
) int(value),
)
if err != nil { if err != nil {
return err return err
}
} }
err = unix.SetsockoptInt( if bind.sock4 != -1 {
bind.sock4, err := unix.SetsockoptInt(
unix.SOL_SOCKET, bind.sock4,
unix.SO_MARK, unix.SOL_SOCKET,
int(value), unix.SO_MARK,
) int(value),
)
if err != nil { if err != nil {
return err return err
}
} }
bind.lastMark = value bind.lastMark = value
@@ -184,9 +214,14 @@ func closeUnblock(fd int) error {
} }
func (bind *NativeBind) Close() error { func (bind *NativeBind) Close() error {
err1 := closeUnblock(bind.sock6) var err1, err2, err3 error
err2 := closeUnblock(bind.sock4) if bind.sock6 != -1 {
err3 := bind.netlinkCancel.Cancel() err1 = closeUnblock(bind.sock6)
}
if bind.sock4 != -1 {
err2 = closeUnblock(bind.sock4)
}
err3 = bind.netlinkCancel.Cancel()
if err1 != nil { if err1 != nil {
return err1 return err1
@@ -199,6 +234,9 @@ func (bind *NativeBind) Close() error {
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint var end NativeEndpoint
if bind.sock6 == -1 {
return 0, nil, syscall.EAFNOSUPPORT
}
n, err := receive6( n, err := receive6(
bind.sock6, bind.sock6,
buff, buff,
@@ -209,6 +247,9 @@ func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint var end NativeEndpoint
if bind.sock4 == -1 {
return 0, nil, syscall.EAFNOSUPPORT
}
n, err := receive4( n, err := receive4(
bind.sock4, bind.sock4,
buff, buff,
@@ -220,8 +261,14 @@ func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
func (bind *NativeBind) Send(buff []byte, end Endpoint) error { func (bind *NativeBind) Send(buff []byte, end Endpoint) error {
nend := end.(*NativeEndpoint) nend := end.(*NativeEndpoint)
if !nend.isV6 { if !nend.isV6 {
if bind.sock4 == -1 {
return syscall.EAFNOSUPPORT
}
return send4(bind.sock4, nend, buff) return send4(bind.sock4, nend, buff)
} else { } else {
if bind.sock6 == -1 {
return syscall.EAFNOSUPPORT
}
return send6(bind.sock6, nend, buff) return send6(bind.sock6, nend, buff)
} }
} }
@@ -309,7 +356,7 @@ func create4(port uint16) (int, uint16, error) {
) )
if err != nil { if err != nil {
return -1, 0, err return FD_ERR, 0, err
} }
addr := unix.SockaddrInet4{ addr := unix.SockaddrInet4{
@@ -340,7 +387,7 @@ func create4(port uint16) (int, uint16, error) {
return unix.Bind(fd, &addr) return unix.Bind(fd, &addr)
}(); err != nil { }(); err != nil {
unix.Close(fd) unix.Close(fd)
return -1, 0, err return FD_ERR, 0, err
} }
return fd, uint16(addr.Port), err return fd, uint16(addr.Port), err
@@ -357,7 +404,7 @@ func create6(port uint16) (int, uint16, error) {
) )
if err != nil { if err != nil {
return -1, 0, err return FD_ERR, 0, err
} }
// set sockopts and bind // set sockopts and bind
@@ -399,7 +446,7 @@ func create6(port uint16) (int, uint16, error) {
}(); err != nil { }(); err != nil {
unix.Close(fd) unix.Close(fd)
return -1, 0, err return FD_ERR, 0, err
} }
return fd, uint16(addr.Port), err return fd, uint16(addr.Port), err
@@ -551,6 +598,7 @@ func (bind *NativeBind) routineRouteListener(device *Device) {
endpoint *Endpoint endpoint *Endpoint
} }
var reqPeer map[uint32]peerEndpointPtr var reqPeer map[uint32]peerEndpointPtr
var reqPeerLock sync.Mutex
defer unix.Close(bind.netlinkSock) defer unix.Close(bind.netlinkSock)
@@ -559,7 +607,7 @@ func (bind *NativeBind) routineRouteListener(device *Device) {
var msgn int var msgn int
for { for {
msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0) msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.ErrorIsEAGAIN(err) { if err == nil || !rwcancel.RetryAfterError(err) {
break break
} }
if !bind.netlinkCancel.ReadyRead() { if !bind.netlinkCancel.ReadyRead() {
@@ -580,7 +628,7 @@ func (bind *NativeBind) routineRouteListener(device *Device) {
switch hdr.Type { switch hdr.Type {
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE: case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
if hdr.Seq <= MaxPeers { if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
if uint(len(remain)) < uint(hdr.Len) { if uint(len(remain)) < uint(hdr.Len) {
break break
} }
@@ -596,10 +644,13 @@ func (bind *NativeBind) routineRouteListener(device *Device) {
} }
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 { if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr])) ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
reqPeerLock.Lock()
if reqPeer == nil { if reqPeer == nil {
reqPeerLock.Unlock()
break break
} }
pePtr, ok := reqPeer[hdr.Seq] pePtr, ok := reqPeer[hdr.Seq]
reqPeerLock.Unlock()
if !ok { if !ok {
break break
} }
@@ -620,7 +671,9 @@ func (bind *NativeBind) routineRouteListener(device *Device) {
} }
break break
} }
reqPeerLock.Lock()
reqPeer = make(map[uint32]peerEndpointPtr) reqPeer = make(map[uint32]peerEndpointPtr)
reqPeerLock.Unlock()
go func() { go func() {
device.peers.mutex.RLock() device.peers.mutex.RLock()
i := uint32(1) i := uint32(1)
@@ -671,10 +724,12 @@ func (bind *NativeBind) routineRouteListener(device *Device) {
uint32(bind.lastMark), uint32(bind.lastMark),
} }
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg)) nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{ reqPeer[i] = peerEndpointPtr{
peer: peer, peer: peer,
endpoint: &peer.endpoint, endpoint: &peer.endpoint,
} }
reqPeerLock.Unlock()
peer.mutex.RUnlock() peer.mutex.RUnlock()
i++ i++
_, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:]) _, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
@@ -27,18 +26,14 @@ const (
PaddingMultiple = 16 PaddingMultiple = 16
) )
/* Implementation specific constants */
const ( const (
QueueOutboundSize = 1024 MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive)
QueueInboundSize = 1024 MaxMessageSize = MaxSegmentSize // maximum size of transport message
QueueHandshakeSize = 1024 MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive)
MaxMessageSize = MaxSegmentSize // maximum size of transport message
MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content
) )
/* Implementation constants */
const ( const (
UnderLoadQueueSize = QueueHandshakeSize / 8 UnderLoadQueueSize = QueueHandshakeSize / 8
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected UnderLoadAfterTime = time.Second // how long does the device remain under load after detected

View File

@@ -1,15 +1,14 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
import ( import (
"./xchacha20poly1305"
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"git.zx2c4.com/wireguard-go/xchacha20poly1305"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"sync" "sync"

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main

View File

@@ -1,13 +1,13 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
import ( import (
"./ratelimiter" "git.zx2c4.com/wireguard-go/ratelimiter"
"git.zx2c4.com/wireguard-go/tun"
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -35,10 +35,12 @@ type Device struct {
} }
net struct { net struct {
mutex sync.RWMutex starting sync.WaitGroup
bind Bind // bind interface stopping sync.WaitGroup
port uint16 // listening port mutex sync.RWMutex
fwmark uint32 // mark value (0 = disabled) bind Bind // bind interface
port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
} }
staticIdentity struct { staticIdentity struct {
@@ -64,7 +66,12 @@ type Device struct {
} }
pool struct { pool struct {
messageBuffers sync.Pool messageBufferPool *sync.Pool
messageBufferReuseChan chan *[MaxMessageSize]byte
inboundElementPool *sync.Pool
inboundElementReuseChan chan *QueueInboundElement
outboundElementPool *sync.Pool
outboundElementReuseChan chan *QueueOutboundElement
} }
queue struct { queue struct {
@@ -78,7 +85,7 @@ type Device struct {
} }
tun struct { tun struct {
device TUNDevice device tun.TUNDevice
mtu int32 mtu int32
} }
} }
@@ -162,16 +169,12 @@ func (device *Device) Up() {
return return
} }
device.state.mutex.Lock()
device.isUp.Set(true) device.isUp.Set(true)
device.state.mutex.Unlock()
deviceUpdateState(device) deviceUpdateState(device)
} }
func (device *Device) Down() { func (device *Device) Down() {
device.state.mutex.Lock()
device.isUp.Set(false) device.isUp.Set(false)
device.state.mutex.Unlock()
deviceUpdateState(device) deviceUpdateState(device)
} }
@@ -244,15 +247,7 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
return nil return nil
} }
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device {
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
}
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
device.pool.messageBuffers.Put(msg)
}
func NewDevice(tun TUNDevice, logger *Logger) *Device {
device := new(Device) device := new(Device)
device.isUp.Set(false) device.isUp.Set(false)
@@ -260,7 +255,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
device.log = logger device.log = logger
device.tun.device = tun device.tun.device = tunDevice
mtu, err := device.tun.device.MTU() mtu, err := device.tun.device.MTU()
if err != nil { if err != nil {
logger.Error.Println("Trouble determining MTU, assuming default:", err) logger.Error.Println("Trouble determining MTU, assuming default:", err)
@@ -276,11 +271,7 @@ func NewDevice(tun TUNDevice, logger *Logger) *Device {
device.indexTable.Init() device.indexTable.Init()
device.allowedips.Reset() device.allowedips.Reset()
device.pool.messageBuffers = sync.Pool{ device.PopulatePools()
New: func() interface{} {
return new([MaxMessageSize]byte)
},
}
// create queues // create queues
@@ -386,10 +377,11 @@ func (device *Device) Close() {
close(device.signals.stop) close(device.signals.stop)
device.RemoveAllPeers()
device.state.stopping.Wait() device.state.stopping.Wait()
device.FlushPacketQueues() device.FlushPacketQueues()
device.RemoveAllPeers()
device.rate.limiter.Close() device.rate.limiter.Close()
device.state.changing.Set(false) device.state.changing.Set(false)

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
@@ -26,8 +25,8 @@ func TestDevice(t *testing.T) {
t.Error("failed to create tun:", err.Error()) t.Error("failed to create tun:", err.Error())
} }
println(tun1) _ = tun1
println(tun2) _ = tun2
// prepare endpoints // prepare endpoints
@@ -41,8 +40,8 @@ func TestDevice(t *testing.T) {
t.Error("failed to create endpoint:", err.Error()) t.Error("failed to create endpoint:", err.Error())
} }
println(end1) _ = end1
println(end2) _ = end2
// create binds // create binds

15
donotuseon_linux.go Normal file
View File

@@ -0,0 +1,15 @@
// +build !android
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
const DoNotUseThisProgramOnLinux = UseTheKernelModuleInstead
// --------------------------------------------------------
// Do not use this on Linux. Instead use the kernel module.
// See wireguard.com/install for more information.
// --------------------------------------------------------

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main

View File

@@ -1,20 +0,0 @@
#!/bin/bash
echo "# This was generated by ./generate-vendor.sh" > Gopkg.lock
echo "# This was generated by ./generate-vendor.sh" > Gopkg.toml
while read -r package; do
cat >> Gopkg.lock <<-_EOF
[[projects]]
branch = "master"
name = "$package"
revision = "$(< "$GOPATH/src/$package/.git/refs/heads/master")"
_EOF
cat >> Gopkg.toml <<-_EOF
[[constraint]]
branch = "master"
name = "$package"
_EOF
done < <(sed -n 's/.*"\(golang.org\/x\/[^/]\+\)\/\?.*".*/\1/p' *.go */*.go | sort | uniq)

View File

@@ -1,13 +1,14 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
import ( import (
"bytes" "bytes"
"errors"
"git.zx2c4.com/wireguard-go/tun"
"os" "os"
"testing" "testing"
) )
@@ -19,7 +20,7 @@ type DummyTUN struct {
name string name string
mtu int mtu int
packets chan []byte packets chan []byte
events chan TUNEvent events chan tun.TUNEvent
} }
func (tun *DummyTUN) File() *os.File { func (tun *DummyTUN) File() *os.File {
@@ -40,23 +41,29 @@ func (tun *DummyTUN) Write(d []byte, offset int) (int, error) {
} }
func (tun *DummyTUN) Close() error { func (tun *DummyTUN) Close() error {
close(tun.events)
close(tun.packets)
return nil return nil
} }
func (tun *DummyTUN) Events() chan TUNEvent { func (tun *DummyTUN) Events() chan tun.TUNEvent {
return tun.events return tun.events
} }
func (tun *DummyTUN) Read(d []byte, offset int) (int, error) { func (tun *DummyTUN) Read(d []byte, offset int) (int, error) {
t := <-tun.packets t, ok := <-tun.packets
if !ok {
return 0, errors.New("device closed")
}
copy(d[offset:], t) copy(d[offset:], t)
return len(t), nil return len(t), nil
} }
func CreateDummyTUN(name string) (TUNDevice, error) { func CreateDummyTUN(name string) (tun.TUNDevice, error) {
var dummy DummyTUN var dummy DummyTUN
dummy.mtu = 0 dummy.mtu = 0
dummy.packets = make(chan []byte, 100) dummy.packets = make(chan []byte, 100)
dummy.events = make(chan tun.TUNEvent, 10)
return &dummy, nil return &dummy, nil
} }

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main

3
ip.go
View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main

View File

@@ -1,13 +1,13 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
import ( import (
"crypto/cipher" "crypto/cipher"
"git.zx2c4.com/wireguard-go/replay"
"sync" "sync"
"time" "time"
) )
@@ -23,7 +23,7 @@ type Keypair struct {
sendNonce uint64 sendNonce uint64
send cipher.AEAD send cipher.AEAD
receive cipher.AEAD receive cipher.AEAD
replayFilter ReplayFilter replayFilter replay.ReplayFilter
isInitiator bool isInitiator bool
created time.Time created time.Time
localIndex uint32 localIndex uint32

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main

22
main.go
View File

@@ -1,17 +1,18 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
import ( import (
"fmt" "fmt"
"git.zx2c4.com/wireguard-go/tun"
"os" "os"
"os/signal" "os/signal"
"runtime" "runtime"
"strconv" "strconv"
"syscall"
) )
const ( const (
@@ -72,6 +73,11 @@ func warning() {
} }
func main() { func main() {
if len(os.Args) == 2 && os.Args[1] == "--version" {
fmt.Printf("wireguard-go v%s\n\nUserspace WireGuard daemon for %s-%s.\nInformation available at https://www.wireguard.com.\nCopyright (C) Jason A. Donenfeld <Jason@zx2c4.com>.\n", WireGuardGoVersion, runtime.GOOS, runtime.GOARCH)
return
}
warning() warning()
// parse arguments // parse arguments
@@ -124,10 +130,10 @@ func main() {
// open TUN device (or use supplied fd) // open TUN device (or use supplied fd)
tun, err := func() (TUNDevice, error) { tun, err := func() (tun.TUNDevice, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD) tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" { if tunFdStr == "" {
return CreateTUN(interfaceName) return tun.CreateTUN(interfaceName, DefaultMTU)
} }
// construct tun device from supplied fd // construct tun device from supplied fd
@@ -138,7 +144,7 @@ func main() {
} }
file := os.NewFile(uintptr(fd), "") file := os.NewFile(uintptr(fd), "")
return CreateTUNFromFile(file) return tun.CreateTUNFromFile(file, DefaultMTU)
}() }()
if err == nil { if err == nil {
@@ -153,6 +159,8 @@ func main() {
fmt.Sprintf("(%s) ", interfaceName), fmt.Sprintf("(%s) ", interfaceName),
) )
logger.Info.Println("Starting wireguard-go version", WireGuardGoVersion)
logger.Debug.Println("Debug log enabled") logger.Debug.Println("Debug log enabled")
if err != nil { if err != nil {
@@ -236,7 +244,7 @@ func main() {
logger.Info.Println("Device started") logger.Info.Println("Device started")
errs := make(chan error) errs := make(chan error)
term := make(chan os.Signal) term := make(chan os.Signal, 1)
uapi, err := UAPIListen(interfaceName, fileUAPI) uapi, err := UAPIListen(interfaceName, fileUAPI)
if err != nil { if err != nil {
@@ -259,7 +267,7 @@ func main() {
// wait for program to terminate // wait for program to terminate
signal.Notify(term, os.Kill) signal.Notify(term, syscall.SIGTERM)
signal.Notify(term, os.Interrupt) signal.Notify(term, os.Interrupt)
select { select {

17
misc.go
View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
@@ -41,23 +40,9 @@ func (a *AtomicBool) Set(val bool) {
atomic.StoreInt32(&a.flag, flag) atomic.StoreInt32(&a.flag, flag)
} }
/* Integer manipulation */
func toInt32(n uint32) int32 {
mask := uint32(1 << 31)
return int32(-(n & mask) + (n & ^mask))
}
func min(a, b uint) uint { func min(a, b uint) uint {
if a > b { if a > b {
return b return b
} }
return a return a
} }
func minUint64(a uint64, b uint64) uint64 {
if a > b {
return b
}
return a
}

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main

View File

@@ -1,14 +1,13 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2015-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
import ( import (
"./tai64n"
"errors" "errors"
"git.zx2c4.com/wireguard-go/tai64n"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305" "golang.org/x/crypto/poly1305"

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
@@ -58,6 +57,7 @@ func TestNoiseHandshake(t *testing.T) {
packet := make([]byte, 0, 256) packet := make([]byte, 0, 256)
writer := bytes.NewBuffer(packet) writer := bytes.NewBuffer(packet)
err = binary.Write(writer, binary.LittleEndian, msg1) err = binary.Write(writer, binary.LittleEndian, msg1)
assertNil(t, err)
peer := dev2.ConsumeMessageInitiation(msg1) peer := dev2.ConsumeMessageInitiation(msg1)
if peer == nil { if peer == nil {
t.Fatal("handshake failed at initiation message") t.Fatal("handshake failed at initiation message")

26
peer.go
View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
@@ -40,9 +39,9 @@ type Peer struct {
newHandshake *Timer newHandshake *Timer
zeroKeyMaterial *Timer zeroKeyMaterial *Timer
persistentKeepalive *Timer persistentKeepalive *Timer
handshakeAttempts uint handshakeAttempts uint32
needAnotherKeepalive bool needAnotherKeepalive AtomicBool
sentLastMinuteHandshake bool sentLastMinuteHandshake AtomicBool
} }
signals struct { signals struct {
@@ -54,7 +53,7 @@ type Peer struct {
nonce chan *QueueOutboundElement // nonce / pre-handshake queue nonce chan *QueueOutboundElement // nonce / pre-handshake queue
outbound chan *QueueOutboundElement // sequential ordering of work outbound chan *QueueOutboundElement // sequential ordering of work
inbound chan *QueueInboundElement // sequential ordering of work inbound chan *QueueInboundElement // sequential ordering of work
packetInNonceQueueIsAwaitingKey bool packetInNonceQueueIsAwaitingKey AtomicBool
} }
routines struct { routines struct {
@@ -171,7 +170,7 @@ func (peer *Peer) Start() {
} }
device := peer.device device := peer.device
device.log.Debug.Println(peer, ": Starting...") device.log.Debug.Println(peer, "- Starting...")
// reset routine state // reset routine state
@@ -241,7 +240,7 @@ func (peer *Peer) Stop() {
peer.routines.mutex.Lock() peer.routines.mutex.Lock()
defer peer.routines.mutex.Unlock() defer peer.routines.mutex.Unlock()
peer.device.log.Debug.Println(peer, ": Stopping...") peer.device.log.Debug.Println(peer, "- Stopping...")
peer.timersStop() peer.timersStop()
@@ -258,3 +257,14 @@ func (peer *Peer) Stop() {
peer.ZeroAndFlushAll() peer.ZeroAndFlushAll()
} }
var roamingDisabled bool
func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
if roamingDisabled {
return
}
peer.mutex.Lock()
peer.endpoint = endpoint
peer.mutex.Unlock()
}

89
pools.go Normal file
View File

@@ -0,0 +1,89 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
import "sync"
func (device *Device) PopulatePools() {
if PreallocatedBuffersPerPool == 0 {
device.pool.messageBufferPool = &sync.Pool{
New: func() interface{} {
return new([MaxMessageSize]byte)
},
}
device.pool.inboundElementPool = &sync.Pool{
New: func() interface{} {
return new(QueueInboundElement)
},
}
device.pool.outboundElementPool = &sync.Pool{
New: func() interface{} {
return new(QueueOutboundElement)
},
}
} else {
device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool)
for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte)
}
device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool)
for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
device.pool.inboundElementReuseChan <- new(QueueInboundElement)
}
device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool)
for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
device.pool.outboundElementReuseChan <- new(QueueOutboundElement)
}
}
}
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
if PreallocatedBuffersPerPool == 0 {
return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte)
} else {
return <-device.pool.messageBufferReuseChan
}
}
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
if PreallocatedBuffersPerPool == 0 {
device.pool.messageBufferPool.Put(msg)
} else {
device.pool.messageBufferReuseChan <- msg
}
}
func (device *Device) GetInboundElement() *QueueInboundElement {
if PreallocatedBuffersPerPool == 0 {
return device.pool.inboundElementPool.Get().(*QueueInboundElement)
} else {
return <-device.pool.inboundElementReuseChan
}
}
func (device *Device) PutInboundElement(msg *QueueInboundElement) {
if PreallocatedBuffersPerPool == 0 {
device.pool.inboundElementPool.Put(msg)
} else {
device.pool.inboundElementReuseChan <- msg
}
}
func (device *Device) GetOutboundElement() *QueueOutboundElement {
if PreallocatedBuffersPerPool == 0 {
return device.pool.outboundElementPool.Get().(*QueueOutboundElement)
} else {
return <-device.pool.outboundElementReuseChan
}
}
func (device *Device) PutOutboundElement(msg *QueueOutboundElement) {
if PreallocatedBuffersPerPool == 0 {
device.pool.outboundElementPool.Put(msg)
} else {
device.pool.outboundElementReuseChan <- msg
}
}

16
queueconstants.go Normal file
View File

@@ -0,0 +1,16 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package main
/* Implementation specific constants */
const (
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
)

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/ */
package ratelimiter package ratelimiter
@@ -27,7 +27,7 @@ type RatelimiterEntry struct {
type Ratelimiter struct { type Ratelimiter struct {
mutex sync.RWMutex mutex sync.RWMutex
stop chan struct{} stopReset chan struct{}
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
} }
@@ -36,8 +36,8 @@ func (rate *Ratelimiter) Close() {
rate.mutex.Lock() rate.mutex.Lock()
defer rate.mutex.Unlock() defer rate.mutex.Unlock()
if rate.stop != nil { if rate.stopReset != nil {
close(rate.stop) close(rate.stopReset)
} }
} }
@@ -47,11 +47,11 @@ func (rate *Ratelimiter) Init() {
// stop any ongoing garbage collection routine // stop any ongoing garbage collection routine
if rate.stop != nil { if rate.stopReset != nil {
close(rate.stop) close(rate.stopReset)
} }
rate.stop = make(chan struct{}) rate.stopReset = make(chan struct{})
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry) rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry) rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
@@ -59,11 +59,16 @@ func (rate *Ratelimiter) Init() {
go func() { go func() {
ticker := time.NewTicker(time.Second) ticker := time.NewTicker(time.Second)
ticker.Stop()
for { for {
select { select {
case <-rate.stop: case _, ok := <-rate.stopReset:
ticker.Stop() ticker.Stop()
return if ok {
ticker = time.NewTicker(time.Second)
} else {
return
}
case <-ticker.C: case <-ticker.C:
func() { func() {
rate.mutex.Lock() rate.mutex.Lock()
@@ -84,6 +89,10 @@ func (rate *Ratelimiter) Init() {
} }
entry.mutex.Unlock() entry.mutex.Unlock()
} }
if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
ticker.Stop()
}
}() }()
} }
} }
@@ -121,8 +130,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
rate.mutex.Lock() rate.mutex.Lock()
if IPv4 != nil { if IPv4 != nil {
rate.tableIPv4[keyIPv4] = entry rate.tableIPv4[keyIPv4] = entry
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
rate.stopReset <- struct{}{}
}
} else { } else {
rate.tableIPv6[keyIPv6] = entry rate.tableIPv6[keyIPv6] = entry
if len(rate.tableIPv6) == 1 && len(rate.tableIPv4) == 0 {
rate.stopReset <- struct{}{}
}
} }
rate.mutex.Unlock() rate.mutex.Unlock()
return true return true

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/ */
package ratelimiter package ratelimiter

View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
@@ -44,59 +43,29 @@ func (elem *QueueInboundElement) IsDropped() bool {
return atomic.LoadInt32(&elem.dropped) == AtomicTrue return atomic.LoadInt32(&elem.dropped) == AtomicTrue
} }
func (device *Device) addToInboundQueue( func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool {
queue chan *QueueInboundElement, select {
element *QueueInboundElement, case inboundQueue <- element:
) {
for {
select { select {
case queue <- element: case decryptionQueue <- element:
return return true
default: default:
select { element.Drop()
case old := <-queue: element.mutex.Unlock()
old.Drop() return false
default:
}
} }
default:
device.PutInboundElement(element)
return false
} }
} }
func (device *Device) addToDecryptionQueue( func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool {
queue chan *QueueInboundElement, select {
element *QueueInboundElement, case queue <- element:
) { return true
for { default:
select { return false
case queue <- element:
return
default:
select {
case old := <-queue:
// drop & release to potential consumer
old.Drop()
old.mutex.Unlock()
default:
}
}
}
}
func (device *Device) addToHandshakeQueue(
queue chan QueueHandshakeElement,
element QueueHandshakeElement,
) {
for {
select {
case queue <- element:
return
default:
select {
case elem := <-queue:
device.PutMessageBuffer(elem.buffer)
default:
}
}
} }
} }
@@ -105,12 +74,12 @@ func (device *Device) addToHandshakeQueue(
* NOTE: Not thread safe, but called by sequential receiver! * NOTE: Not thread safe, but called by sequential receiver!
*/ */
func (peer *Peer) keepKeyFreshReceiving() { func (peer *Peer) keepKeyFreshReceiving() {
if peer.timers.sentLastMinuteHandshake { if peer.timers.sentLastMinuteHandshake.Get() {
return return
} }
keypair := peer.keypairs.Current() keypair := peer.keypairs.Current()
if keypair != nil && keypair.isInitiator && time.Now().Sub(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { if keypair != nil && keypair.isInitiator && time.Now().Sub(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
peer.timers.sentLastMinuteHandshake = true peer.timers.sentLastMinuteHandshake.Set(true)
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
} }
@@ -125,11 +94,11 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
logDebug := device.log.Debug logDebug := device.log.Debug
defer func() { defer func() {
logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped") logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
device.state.stopping.Done() device.net.stopping.Done()
}() }()
logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - starting") logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - starting")
device.state.starting.Done() device.net.starting.Done()
// receive datagrams until conn is closed // receive datagrams until conn is closed
@@ -155,6 +124,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
} }
if err != nil { if err != nil {
device.PutMessageBuffer(buffer)
return return
} }
@@ -177,7 +147,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
// check size // check size
if len(packet) < MessageTransportType { if len(packet) < MessageTransportSize {
continue continue
} }
@@ -199,23 +169,23 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
} }
// create work element // create work element
peer := value.peer peer := value.peer
elem := &QueueInboundElement{ elem := device.GetInboundElement()
packet: packet, elem.packet = packet
buffer: buffer, elem.buffer = buffer
keypair: keypair, elem.keypair = keypair
dropped: AtomicFalse, elem.dropped = AtomicFalse
endpoint: endpoint, elem.endpoint = endpoint
} elem.counter = 0
elem.mutex = sync.Mutex{}
elem.mutex.Lock() elem.mutex.Lock()
// add to decryption queues // add to decryption queues
if peer.isRunning.Get() { if peer.isRunning.Get() {
device.addToDecryptionQueue(device.queue.decryption, elem) if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
device.addToInboundQueue(peer.queue.inbound, elem) buffer = device.GetMessageBuffer()
buffer = device.GetMessageBuffer() }
} }
continue continue
@@ -236,7 +206,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
} }
if okay { if okay {
device.addToHandshakeQueue( if (device.addToHandshakeQueue(
device.queue.handshake, device.queue.handshake,
QueueHandshakeElement{ QueueHandshakeElement{
msgType: msgType, msgType: msgType,
@@ -244,8 +214,9 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
packet: packet, packet: packet,
endpoint: endpoint, endpoint: endpoint,
}, },
) )) {
buffer = device.GetMessageBuffer() buffer = device.GetMessageBuffer()
}
} }
} }
} }
@@ -308,6 +279,7 @@ func (device *Device) RoutineDecryption() {
) )
if err != nil { if err != nil {
elem.Drop() elem.Drop()
device.PutMessageBuffer(elem.buffer)
} }
elem.mutex.Unlock() elem.mutex.Unlock()
} }
@@ -322,18 +294,26 @@ func (device *Device) RoutineHandshake() {
logError := device.log.Error logError := device.log.Error
logDebug := device.log.Debug logDebug := device.log.Debug
var elem QueueHandshakeElement
var ok bool
defer func() { defer func() {
logDebug.Println("Routine: handshake worker - stopped") logDebug.Println("Routine: handshake worker - stopped")
device.state.stopping.Done() device.state.stopping.Done()
if elem.buffer != nil {
device.PutMessageBuffer(elem.buffer)
}
}() }()
logDebug.Println("Routine: handshake worker - started") logDebug.Println("Routine: handshake worker - started")
device.state.starting.Done() device.state.starting.Done()
var elem QueueHandshakeElement
var ok bool
for { for {
if elem.buffer != nil {
device.PutMessageBuffer(elem.buffer)
elem.buffer = nil
}
select { select {
case elem, ok = <-device.queue.handshake: case elem, ok = <-device.queue.handshake:
case <-device.signals.stop: case <-device.signals.stop:
@@ -440,12 +420,9 @@ func (device *Device) RoutineHandshake() {
peer.timersAnyAuthenticatedPacketReceived() peer.timersAnyAuthenticatedPacketReceived()
// update endpoint // update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
peer.mutex.Lock() logDebug.Println(peer, "- Received handshake initiation")
peer.endpoint = elem.endpoint
peer.mutex.Unlock()
logDebug.Println(peer, ": Received handshake initiation")
peer.SendHandshakeResponse() peer.SendHandshakeResponse()
@@ -466,19 +443,16 @@ func (device *Device) RoutineHandshake() {
peer := device.ConsumeMessageResponse(&msg) peer := device.ConsumeMessageResponse(&msg)
if peer == nil { if peer == nil {
logInfo.Println( logInfo.Println(
"Recieved invalid response message from", "Received invalid response message from",
elem.endpoint.DstToString(), elem.endpoint.DstToString(),
) )
continue continue
} }
// update endpoint // update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
peer.mutex.Lock() logDebug.Println(peer, "- Received handshake response")
peer.endpoint = elem.endpoint
peer.mutex.Unlock()
logDebug.Println(peer, ": Received handshake response")
// update timers // update timers
@@ -490,7 +464,7 @@ func (device *Device) RoutineHandshake() {
err = peer.BeginSymmetricSession() err = peer.BeginSymmetricSession()
if err != nil { if err != nil {
logError.Println(peer, ": Failed to derive keypair:", err) logError.Println(peer, "- Failed to derive keypair:", err)
continue continue
} }
@@ -512,23 +486,39 @@ func (peer *Peer) RoutineSequentialReceiver() {
logError := device.log.Error logError := device.log.Error
logDebug := device.log.Debug logDebug := device.log.Debug
var elem *QueueInboundElement
var ok bool
defer func() { defer func() {
logDebug.Println(peer, ": Routine: sequential receiver - stopped") logDebug.Println(peer, "- Routine: sequential receiver - stopped")
peer.routines.stopping.Done() peer.routines.stopping.Done()
if elem != nil {
if !elem.IsDropped() {
device.PutMessageBuffer(elem.buffer)
}
device.PutInboundElement(elem)
}
}() }()
logDebug.Println(peer, ": Routine: sequential receiver - started") logDebug.Println(peer, "- Routine: sequential receiver - started")
peer.routines.starting.Done() peer.routines.starting.Done()
for { for {
if elem != nil {
if !elem.IsDropped() {
device.PutMessageBuffer(elem.buffer)
}
device.PutInboundElement(elem)
elem = nil
}
select { select {
case <-peer.routines.stop: case <-peer.routines.stop:
return return
case elem, ok := <-peer.queue.inbound: case elem, ok = <-peer.queue.inbound:
if !ok { if !ok {
return return
@@ -544,15 +534,12 @@ func (peer *Peer) RoutineSequentialReceiver() {
// check for replay // check for replay
if !elem.keypair.replayFilter.ValidateCounter(elem.counter) { if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
continue continue
} }
// update endpoint // update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
peer.mutex.Lock()
peer.endpoint = elem.endpoint
peer.mutex.Unlock()
// check if using new keypair // check if using new keypair
if peer.ReceivedWithKeypair(elem.keypair) { if peer.ReceivedWithKeypair(elem.keypair) {
@@ -570,7 +557,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
// check for keepalive // check for keepalive
if len(elem.packet) == 0 { if len(elem.packet) == 0 {
logDebug.Println(peer, ": Receiving keepalive packet") logDebug.Println(peer, "- Receiving keepalive packet")
continue continue
} }
peer.timersDataReceived() peer.timersDataReceived()
@@ -642,10 +629,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
offset := MessageTransportOffsetContent offset := MessageTransportOffsetContent
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet))) atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
_, err := device.tun.device.Write( _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
elem.buffer[:offset+len(elem.packet)],
offset)
device.PutMessageBuffer(elem.buffer)
if err != nil { if err != nil {
logError.Println("Failed to write packet to TUN device:", err) logError.Println("Failed to write packet to TUN device:", err)
} }

View File

@@ -1,12 +1,9 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package replay
/* Copyright (C) 2015-2017 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */
/* Implementation of RFC6479 /* Implementation of RFC6479
* https://tools.ietf.org/html/rfc6479 * https://tools.ietf.org/html/rfc6479
@@ -32,6 +29,13 @@ const (
BacktrackWords = CounterBitsTotal / _WordSize BacktrackWords = CounterBitsTotal / _WordSize
) )
func minUint64(a uint64, b uint64) uint64 {
if a > b {
return b
}
return a
}
type ReplayFilter struct { type ReplayFilter struct {
counter uint64 counter uint64
backtrack [BacktrackWords]uintptr backtrack [BacktrackWords]uintptr
@@ -42,8 +46,8 @@ func (filter *ReplayFilter) Init() {
filter.backtrack[0] = 0 filter.backtrack[0] = 0
} }
func (filter *ReplayFilter) ValidateCounter(counter uint64) bool { func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
if counter >= RejectAfterMessages { if counter >= limit {
return false return false
} }

View File

@@ -1,10 +1,9 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package replay
import ( import (
"testing" "testing"
@@ -15,6 +14,8 @@ import (
* *
*/ */
const RejectAfterMessages = (1 << 64) - (1 << 4) - 1
func TestReplay(t *testing.T) { func TestReplay(t *testing.T) {
var filter ReplayFilter var filter ReplayFilter
@@ -23,7 +24,7 @@ func TestReplay(t *testing.T) {
testNumber := 0 testNumber := 0
T := func(n uint64, v bool) { T := func(n uint64, v bool) {
testNumber++ testNumber++
if filter.ValidateCounter(n) != v { if filter.ValidateCounter(n, RejectAfterMessages) != v {
t.Fatal("Test", testNumber, "failed", n, v) t.Fatal("Test", testNumber, "failed", n, v)
} }
} }

24
rwcancel/fdset_default.go Normal file
View File

@@ -0,0 +1,24 @@
// +build !freebsd
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
import "golang.org/x/sys/unix"
type fdSet struct {
fdset unix.FdSet
}
func (fdset *fdSet) set(i int) {
bits := 32 << (^uint(0) >> 63)
fdset.fdset.Bits[i/bits] |= 1 << uint(i%bits)
}
func (fdset *fdSet) check(i int) bool {
bits := 32 << (^uint(0) >> 63)
return (fdset.fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0
}

22
rwcancel/fdset_freebsd.go Normal file
View File

@@ -0,0 +1,22 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package rwcancel
import "golang.org/x/sys/unix"
type fdSet struct {
fdset unix.FdSet
}
func (fdset *fdSet) set(i int) {
bits := 32 << (^uint(0) >> 63)
fdset.fdset.X__fds_bits[i/bits] |= 1 << uint(i%bits)
}
func (fdset *fdSet) check(i int) bool {
bits := 32 << (^uint(0) >> 63)
return (fdset.fdset.X__fds_bits[i/bits] & (1 << uint(i%bits))) != 0
}

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/ */
package rwcancel package rwcancel
@@ -12,26 +12,6 @@ import (
"syscall" "syscall"
) )
type RWCancel struct {
fd int
closingReader *os.File
closingWriter *os.File
}
type fdSet struct {
fdset unix.FdSet
}
func (fdset *fdSet) set(i int) {
bits := 32 << (^uint(0) >> 63)
fdset.fdset.Bits[i/bits] |= 1 << uint(i%bits)
}
func (fdset *fdSet) check(i int) bool {
bits := 32 << (^uint(0) >> 63)
return (fdset.fdset.Bits[i/bits] & (1 << uint(i%bits))) != 0
}
func max(a, b int) int { func max(a, b int) int {
if a > b { if a > b {
return a return a
@@ -39,6 +19,12 @@ func max(a, b int) int {
return b return b
} }
type RWCancel struct {
fd int
closingReader *os.File
closingWriter *os.File
}
func NewRWCancel(fd int) (*RWCancel, error) { func NewRWCancel(fd int) (*RWCancel, error) {
err := unix.SetNonblock(fd, true) err := unix.SetNonblock(fd, true)
if err != nil { if err != nil {
@@ -54,15 +40,16 @@ func NewRWCancel(fd int) (*RWCancel, error) {
return &rwcancel, nil return &rwcancel, nil
} }
/* https://golang.org/src/crypto/rand/eagain.go */ func RetryAfterError(err error) bool {
func ErrorIsEAGAIN(err error) bool {
if pe, ok := err.(*os.PathError); ok { if pe, ok := err.(*os.PathError); ok {
if errno, ok := pe.Err.(syscall.Errno); ok && errno == syscall.EAGAIN { err = pe.Err
}
if errno, ok := err.(syscall.Errno); ok {
switch errno {
case syscall.EAGAIN, syscall.EINTR:
return true return true
} }
}
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EAGAIN {
return true
} }
return false return false
} }
@@ -100,7 +87,7 @@ func (rw *RWCancel) ReadyWrite() bool {
func (rw *RWCancel) Read(p []byte) (n int, err error) { func (rw *RWCancel) Read(p []byte) (n int, err error) {
for { for {
n, err := unix.Read(rw.fd, p) n, err := unix.Read(rw.fd, p)
if err == nil || !ErrorIsEAGAIN(err) { if err == nil || !RetryAfterError(err) {
return n, err return n, err
} }
if !rw.ReadyRead() { if !rw.ReadyRead() {
@@ -112,7 +99,7 @@ func (rw *RWCancel) Read(p []byte) (n int, err error) {
func (rw *RWCancel) Write(p []byte) (n int, err error) { func (rw *RWCancel) Write(p []byte) (n int, err error) {
for { for {
n, err := unix.Write(rw.fd, p) n, err := unix.Write(rw.fd, p)
if err == nil || !ErrorIsEAGAIN(err) { if err == nil || !RetryAfterError(err) {
return n, err return n, err
} }
if !rw.ReadyWrite() { if !rw.ReadyWrite() {

View File

@@ -1,6 +1,8 @@
// +build !linux
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/ */
package rwcancel package rwcancel

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/ */
package rwcancel package rwcancel

180
send.go
View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
@@ -53,10 +52,14 @@ type QueueOutboundElement struct {
} }
func (device *Device) NewOutboundElement() *QueueOutboundElement { func (device *Device) NewOutboundElement() *QueueOutboundElement {
return &QueueOutboundElement{ elem := device.GetOutboundElement()
dropped: AtomicFalse, elem.dropped = AtomicFalse
buffer: device.pool.messageBuffers.Get().(*[MaxMessageSize]byte), elem.buffer = device.GetMessageBuffer()
} elem.mutex = sync.Mutex{}
elem.nonce = 0
elem.keypair = nil
elem.peer = nil
return elem
} }
func (elem *QueueOutboundElement) Drop() { func (elem *QueueOutboundElement) Drop() {
@@ -67,10 +70,7 @@ func (elem *QueueOutboundElement) IsDropped() bool {
return atomic.LoadInt32(&elem.dropped) == AtomicTrue return atomic.LoadInt32(&elem.dropped) == AtomicTrue
} }
func addToOutboundQueue( func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) {
queue chan *QueueOutboundElement,
element *QueueOutboundElement,
) {
for { for {
select { select {
case queue <- element: case queue <- element:
@@ -78,53 +78,53 @@ func addToOutboundQueue(
default: default:
select { select {
case old := <-queue: case old := <-queue:
old.Drop() device.PutMessageBuffer(old.buffer)
device.PutOutboundElement(old)
default: default:
} }
} }
} }
} }
func addToEncryptionQueue( func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) {
queue chan *QueueOutboundElement, select {
element *QueueOutboundElement, case outboundQueue <- element:
) {
for {
select { select {
case queue <- element: case encryptionQueue <- element:
return return
default: default:
select { element.Drop()
case old := <-queue: element.peer.device.PutMessageBuffer(element.buffer)
// drop & release to potential consumer element.mutex.Unlock()
old.Drop()
old.mutex.Unlock()
default:
}
} }
default:
element.peer.device.PutMessageBuffer(element.buffer)
element.peer.device.PutOutboundElement(element)
} }
} }
/* Queues a keepalive if no packets are queued for peer /* Queues a keepalive if no packets are queued for peer
*/ */
func (peer *Peer) SendKeepalive() bool { func (peer *Peer) SendKeepalive() bool {
if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey || !peer.isRunning.Get() { if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() {
return false return false
} }
elem := peer.device.NewOutboundElement() elem := peer.device.NewOutboundElement()
elem.packet = nil elem.packet = nil
select { select {
case peer.queue.nonce <- elem: case peer.queue.nonce <- elem:
peer.device.log.Debug.Println(peer, ": Sending keepalive packet") peer.device.log.Debug.Println(peer, "- Sending keepalive packet")
return true return true
default: default:
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
return false return false
} }
} }
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if !isRetry { if !isRetry {
peer.timers.handshakeAttempts = 0 atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
} }
peer.handshake.mutex.RLock() peer.handshake.mutex.RLock()
@@ -142,11 +142,11 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
peer.handshake.lastSentHandshake = time.Now() peer.handshake.lastSentHandshake = time.Now()
peer.handshake.mutex.Unlock() peer.handshake.mutex.Unlock()
peer.device.log.Debug.Println(peer, ": Sending handshake initiation") peer.device.log.Debug.Println(peer, "- Sending handshake initiation")
msg, err := peer.device.CreateMessageInitiation(peer) msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil { if err != nil {
peer.device.log.Error.Println(peer, ": Failed to create initiation message:", err) peer.device.log.Error.Println(peer, "- Failed to create initiation message:", err)
return err return err
} }
@@ -161,7 +161,7 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
err = peer.SendBuffer(packet) err = peer.SendBuffer(packet)
if err != nil { if err != nil {
peer.device.log.Error.Println(peer, ": Failed to send handshake initiation", err) peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err)
} }
peer.timersHandshakeInitiated() peer.timersHandshakeInitiated()
@@ -173,11 +173,11 @@ func (peer *Peer) SendHandshakeResponse() error {
peer.handshake.lastSentHandshake = time.Now() peer.handshake.lastSentHandshake = time.Now()
peer.handshake.mutex.Unlock() peer.handshake.mutex.Unlock()
peer.device.log.Debug.Println(peer, ": Sending handshake response") peer.device.log.Debug.Println(peer, "- Sending handshake response")
response, err := peer.device.CreateMessageResponse(peer) response, err := peer.device.CreateMessageResponse(peer)
if err != nil { if err != nil {
peer.device.log.Error.Println(peer, ": Failed to create response message:", err) peer.device.log.Error.Println(peer, "- Failed to create response message:", err)
return err return err
} }
@@ -189,7 +189,7 @@ func (peer *Peer) SendHandshakeResponse() error {
err = peer.BeginSymmetricSession() err = peer.BeginSymmetricSession()
if err != nil { if err != nil {
peer.device.log.Error.Println(peer, ": Failed to derive keypair:", err) peer.device.log.Error.Println(peer, "- Failed to derive keypair:", err)
return err return err
} }
@@ -199,7 +199,7 @@ func (peer *Peer) SendHandshakeResponse() error {
err = peer.SendBuffer(packet) err = peer.SendBuffer(packet)
if err != nil { if err != nil {
peer.device.log.Error.Println(peer, ": Failed to send handshake response", err) peer.device.log.Error.Println(peer, "- Failed to send handshake response", err)
} }
return err return err
} }
@@ -243,8 +243,6 @@ func (peer *Peer) keepKeyFreshSending() {
*/ */
func (device *Device) RoutineReadFromTUN() { func (device *Device) RoutineReadFromTUN() {
elem := device.NewOutboundElement()
logDebug := device.log.Debug logDebug := device.log.Debug
logError := device.log.Error logError := device.log.Error
@@ -256,7 +254,14 @@ func (device *Device) RoutineReadFromTUN() {
logDebug.Println("Routine: TUN reader - started") logDebug.Println("Routine: TUN reader - started")
device.state.starting.Done() device.state.starting.Done()
var elem *QueueOutboundElement
for { for {
if elem != nil {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
elem = device.NewOutboundElement()
// read packet // read packet
@@ -264,8 +269,12 @@ func (device *Device) RoutineReadFromTUN() {
size, err := device.tun.device.Read(elem.buffer[:], offset) size, err := device.tun.device.Read(elem.buffer[:], offset)
if err != nil { if err != nil {
logError.Println("Failed to read packet from TUN device:", err) if !device.isClosed.Get() {
device.Close() logError.Println("Failed to read packet from TUN device:", err)
device.Close()
}
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
return return
} }
@@ -304,11 +313,11 @@ func (device *Device) RoutineReadFromTUN() {
// insert into nonce/pre-handshake queue // insert into nonce/pre-handshake queue
if peer.isRunning.Get() { if peer.isRunning.Get() {
if peer.queue.packetInNonceQueueIsAwaitingKey { if peer.queue.packetInNonceQueueIsAwaitingKey.Get() {
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
addToOutboundQueue(peer.queue.nonce, elem) addToNonceQueue(peer.queue.nonce, elem, device)
elem = device.NewOutboundElement() elem = nil
} }
} }
} }
@@ -332,28 +341,31 @@ func (peer *Peer) RoutineNonce() {
device := peer.device device := peer.device
logDebug := device.log.Debug logDebug := device.log.Debug
defer func() {
logDebug.Println(peer, ": Routine: nonce worker - stopped")
peer.queue.packetInNonceQueueIsAwaitingKey = false
peer.routines.stopping.Done()
}()
flush := func() { flush := func() {
for { for {
select { select {
case <-peer.queue.nonce: case elem := <-peer.queue.nonce:
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
default: default:
return return
} }
} }
} }
defer func() {
flush()
logDebug.Println(peer, "- Routine: nonce worker - stopped")
peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
peer.routines.stopping.Done()
}()
peer.routines.starting.Done() peer.routines.starting.Done()
logDebug.Println(peer, ": Routine: nonce worker - started") logDebug.Println(peer, "- Routine: nonce worker - started")
for { for {
NextPacket: NextPacket:
peer.queue.packetInNonceQueueIsAwaitingKey = false peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
select { select {
case <-peer.routines.stop: case <-peer.routines.stop:
@@ -381,7 +393,7 @@ func (peer *Peer) RoutineNonce() {
break break
} }
} }
peer.queue.packetInNonceQueueIsAwaitingKey = true peer.queue.packetInNonceQueueIsAwaitingKey.Set(true)
// no suitable key pair, request for new handshake // no suitable key pair, request for new handshake
@@ -394,21 +406,25 @@ func (peer *Peer) RoutineNonce() {
// wait for key to be established // wait for key to be established
logDebug.Println(peer, ": Awaiting keypair") logDebug.Println(peer, "- Awaiting keypair")
select { select {
case <-peer.signals.newKeypairArrived: case <-peer.signals.newKeypairArrived:
logDebug.Println(peer, ": Obtained awaited keypair") logDebug.Println(peer, "- Obtained awaited keypair")
case <-peer.signals.flushNonceQueue: case <-peer.signals.flushNonceQueue:
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
flush() flush()
goto NextPacket goto NextPacket
case <-peer.routines.stop: case <-peer.routines.stop:
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
return return
} }
} }
peer.queue.packetInNonceQueueIsAwaitingKey = false peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
// populate work element // populate work element
@@ -419,6 +435,8 @@ func (peer *Peer) RoutineNonce() {
if elem.nonce >= RejectAfterMessages { if elem.nonce >= RejectAfterMessages {
atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages) atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
goto NextPacket goto NextPacket
} }
@@ -427,9 +445,7 @@ func (peer *Peer) RoutineNonce() {
elem.mutex.Lock() elem.mutex.Lock()
// add to parallel and sequential queue // add to parallel and sequential queue
addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem)
addToEncryptionQueue(device.queue.encryption, elem)
addToOutboundQueue(peer.queue.outbound, elem)
} }
} }
} }
@@ -446,6 +462,19 @@ func (device *Device) RoutineEncryption() {
logDebug := device.log.Debug logDebug := device.log.Debug
defer func() { defer func() {
for {
select {
case elem, ok := <-device.queue.encryption:
if ok && !elem.IsDropped() {
elem.Drop()
device.PutMessageBuffer(elem.buffer)
elem.mutex.Unlock()
}
default:
goto out
}
}
out:
logDebug.Println("Routine: encryption worker - stopped") logDebug.Println("Routine: encryption worker - stopped")
device.state.stopping.Done() device.state.stopping.Done()
}() }()
@@ -488,11 +517,13 @@ func (device *Device) RoutineEncryption() {
// pad content to multiple of 16 // pad content to multiple of 16
mtu := int(atomic.LoadInt32(&device.tun.mtu)) mtu := int(atomic.LoadInt32(&device.tun.mtu))
rem := len(elem.packet) % PaddingMultiple lastUnit := len(elem.packet) % mtu
if rem > 0 { paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
for i := 0; i < PaddingMultiple-rem && len(elem.packet) < mtu; i++ { if paddedSize > mtu {
elem.packet = append(elem.packet, 0) paddedSize = mtu
} }
for i := len(elem.packet); i < paddedSize; i++ {
elem.packet = append(elem.packet, 0)
} }
// encrypt content and release to consumer // encrypt content and release to consumer
@@ -519,13 +550,30 @@ func (peer *Peer) RoutineSequentialSender() {
device := peer.device device := peer.device
logDebug := device.log.Debug logDebug := device.log.Debug
logError := device.log.Error
defer func() { defer func() {
logDebug.Println(peer, ": Routine: sequential sender - stopped") for {
select {
case elem, ok := <-peer.queue.outbound:
if ok {
if !elem.IsDropped() {
device.PutMessageBuffer(elem.buffer)
elem.Drop()
}
device.PutOutboundElement(elem)
elem.mutex.Unlock()
}
default:
goto out
}
}
out:
logDebug.Println(peer, "- Routine: sequential sender - stopped")
peer.routines.stopping.Done() peer.routines.stopping.Done()
}() }()
logDebug.Println(peer, ": Routine: sequential sender - started") logDebug.Println(peer, "- Routine: sequential sender - started")
peer.routines.starting.Done() peer.routines.starting.Done()
@@ -543,6 +591,7 @@ func (peer *Peer) RoutineSequentialSender() {
elem.mutex.Lock() elem.mutex.Lock()
if elem.IsDropped() { if elem.IsDropped() {
device.PutOutboundElement(elem)
continue continue
} }
@@ -554,8 +603,9 @@ func (peer *Peer) RoutineSequentialSender() {
length := uint64(len(elem.packet)) length := uint64(len(elem.packet))
err := peer.SendBuffer(elem.packet) err := peer.SendBuffer(elem.packet)
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
if err != nil { if err != nil {
logDebug.Println("Failed to send authenticated packet to peer", peer) logError.Println(peer, "- Failed to send data packet", err)
continue continue
} }
atomic.AddUint64(&peer.stats.txBytes, length) atomic.AddUint64(&peer.stats.txBytes, length)

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/ */
package tai64n package tai64n

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/ */
package tai64n package tai64n

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2015-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* *
* This is based heavily on timers.c from the kernel implementation. * This is based heavily on timers.c from the kernel implementation.
*/ */
@@ -20,7 +20,7 @@ import (
type Timer struct { type Timer struct {
timer *time.Timer timer *time.Timer
modifyingLock sync.Mutex modifyingLock sync.RWMutex
runningLock sync.Mutex runningLock sync.Mutex
isPending bool isPending bool
} }
@@ -67,12 +67,18 @@ func (timer *Timer) DelSync() {
timer.runningLock.Unlock() timer.runningLock.Unlock()
} }
func (timer *Timer) IsPending() bool {
timer.modifyingLock.RLock()
defer timer.modifyingLock.RUnlock()
return timer.isPending
}
func (peer *Peer) timersActive() bool { func (peer *Peer) timersActive() bool {
return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0 return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0
} }
func expiredRetransmitHandshake(peer *Peer) { func expiredRetransmitHandshake(peer *Peer) {
if peer.timers.handshakeAttempts > MaxTimerHandshakes { if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
peer.device.log.Debug.Printf("%s: Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2) peer.device.log.Debug.Printf("%s: Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2)
if peer.timersActive() { if peer.timersActive() {
@@ -87,12 +93,12 @@ func expiredRetransmitHandshake(peer *Peer) {
/* We set a timer for destroying any residue that might be left /* We set a timer for destroying any residue that might be left
* of a partial exchange. * of a partial exchange.
*/ */
if peer.timersActive() && !peer.timers.zeroKeyMaterial.isPending { if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3) peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
} }
} else { } else {
peer.timers.handshakeAttempts++ atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
peer.device.log.Debug.Printf("%s: Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), peer.timers.handshakeAttempts+1) peer.device.log.Debug.Printf("%s: Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */ /* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.mutex.Lock() peer.mutex.Lock()
@@ -107,8 +113,8 @@ func expiredRetransmitHandshake(peer *Peer) {
func expiredSendKeepalive(peer *Peer) { func expiredSendKeepalive(peer *Peer) {
peer.SendKeepalive() peer.SendKeepalive()
if peer.timers.needAnotherKeepalive { if peer.timers.needAnotherKeepalive.Get() {
peer.timers.needAnotherKeepalive = false peer.timers.needAnotherKeepalive.Set(false)
if peer.timersActive() { if peer.timersActive() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout) peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} }
@@ -128,7 +134,7 @@ func expiredNewHandshake(peer *Peer) {
} }
func expiredZeroKeyMaterial(peer *Peer) { func expiredZeroKeyMaterial(peer *Peer) {
peer.device.log.Debug.Printf(":%s Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds())) peer.device.log.Debug.Printf("%s: Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds()))
peer.ZeroAndFlushAll() peer.ZeroAndFlushAll()
} }
@@ -140,7 +146,7 @@ func expiredPersistentKeepalive(peer *Peer) {
/* Should be called after an authenticated data packet is sent. */ /* Should be called after an authenticated data packet is sent. */
func (peer *Peer) timersDataSent() { func (peer *Peer) timersDataSent() {
if peer.timersActive() && !peer.timers.newHandshake.isPending { if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout) peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout)
} }
} }
@@ -148,10 +154,10 @@ func (peer *Peer) timersDataSent() {
/* Should be called after an authenticated data packet is received. */ /* Should be called after an authenticated data packet is received. */
func (peer *Peer) timersDataReceived() { func (peer *Peer) timersDataReceived() {
if peer.timersActive() { if peer.timersActive() {
if !peer.timers.sendKeepalive.isPending { if !peer.timers.sendKeepalive.IsPending() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout) peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} else { } else {
peer.timers.needAnotherKeepalive = true peer.timers.needAnotherKeepalive.Set(true)
} }
} }
} }
@@ -182,8 +188,8 @@ func (peer *Peer) timersHandshakeComplete() {
if peer.timersActive() { if peer.timersActive() {
peer.timers.retransmitHandshake.Del() peer.timers.retransmitHandshake.Del()
} }
peer.timers.handshakeAttempts = 0 atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
peer.timers.sentLastMinuteHandshake = false peer.timers.sentLastMinuteHandshake.Set(false)
atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano()) atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
} }
@@ -207,9 +213,9 @@ func (peer *Peer) timersInit() {
peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake) peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial) peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive) peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
peer.timers.handshakeAttempts = 0 atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
peer.timers.sentLastMinuteHandshake = false peer.timers.sentLastMinuteHandshake.Set(false)
peer.timers.needAnotherKeepalive = false peer.timers.needAnotherKeepalive.Set(false)
} }
func (peer *Peer) timersStop() { func (peer *Peer) timersStop() {

32
tun.go
View File

@@ -1,45 +1,28 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
import ( import (
"os" "git.zx2c4.com/wireguard-go/tun"
"sync/atomic" "sync/atomic"
) )
const DefaultMTU = 1420 const DefaultMTU = 1420
type TUNEvent int
const (
TUNEventUp = 1 << iota
TUNEventDown
TUNEventMTUUpdate
)
type TUNDevice interface {
File() *os.File // returns the file descriptor of the device
Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
MTU() (int, error) // returns the MTU of the device
Name() (string, error) // fetches and returns the current name
Events() chan TUNEvent // returns a constant channel of events related to the device
Close() error // stops the device and closes the event channel
}
func (device *Device) RoutineTUNEventReader() { func (device *Device) RoutineTUNEventReader() {
setUp := false setUp := false
logDebug := device.log.Debug
logInfo := device.log.Info logInfo := device.log.Info
logError := device.log.Error logError := device.log.Error
logDebug.Println("Routine: event worker - started")
device.state.starting.Done() device.state.starting.Done()
for event := range device.tun.device.Events() { for event := range device.tun.device.Events() {
if event&TUNEventMTUUpdate != 0 { if event&tun.TUNEventMTUUpdate != 0 {
mtu, err := device.tun.device.MTU() mtu, err := device.tun.device.MTU()
old := atomic.LoadInt32(&device.tun.mtu) old := atomic.LoadInt32(&device.tun.mtu)
if err != nil { if err != nil {
@@ -54,18 +37,19 @@ func (device *Device) RoutineTUNEventReader() {
} }
} }
if event&TUNEventUp != 0 && !setUp { if event&tun.TUNEventUp != 0 && !setUp {
logInfo.Println("Interface set up") logInfo.Println("Interface set up")
setUp = true setUp = true
device.Up() device.Up()
} }
if event&TUNEventDown != 0 && setUp { if event&tun.TUNEventDown != 0 && setUp {
logInfo.Println("Interface set down") logInfo.Println("Interface set down")
setUp = false setUp = false
device.Down() device.Down()
} }
} }
logDebug.Println("Routine: event worker - stopped")
device.state.stopping.Done() device.state.stopping.Done()
} }

26
tun/tun.go Normal file
View File

@@ -0,0 +1,26 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package tun
import "os"
type TUNEvent int
const (
TUNEventUp = 1 << iota
TUNEventDown
TUNEventMTUUpdate
)
type TUNDevice interface {
File() *os.File // returns the file descriptor of the device
Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
MTU() (int, error) // returns the MTU of the device
Name() (string, error) // fetches and returns the current name
Events() chan TUNEvent // returns a constant channel of events related to the device
Close() error // stops the device and closes the event channel
}

View File

@@ -1,22 +1,22 @@
// +build !ios
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package tun
import ( import (
"./rwcancel"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"git.zx2c4.com/wireguard-go/rwcancel"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"io/ioutil" "io/ioutil"
"net" "net"
"os" "os"
"time" "syscall"
"unsafe" "unsafe"
) )
@@ -35,25 +35,78 @@ type sockaddrCtl struct {
scReserved [5]uint32 scReserved [5]uint32
} }
// NativeTun is a hack to work around the first 4 bytes "packet type nativeTun struct {
// information" because there doesn't seem to be an IFF_NO_PI for darwin. name string
type NativeTun struct { fd *os.File
name string rwcancel *rwcancel.RWCancel
fd *os.File events chan TUNEvent
rwcancel *rwcancel.RWCancel errors chan error
mtu int routeSocket int
events chan TUNEvent
errors chan error
statusListenersShutdown chan struct{}
} }
var sockaddrCtlSize uintptr = 32 var sockaddrCtlSize uintptr = 32
func CreateTUN(name string) (TUNDevice, error) { func (tun *nativeTun) routineRouteListener(tunIfindex int) {
var (
statusUp bool
statusMTU int
)
defer close(tun.events)
data := make([]byte, os.Getpagesize())
for {
retry:
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
goto retry
}
tun.errors <- err
return
}
if n < 14 {
continue
}
if data[3 /* type */] != unix.RTM_IFINFO {
continue
}
ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */])))
if ifindex != tunIfindex {
continue
}
iface, err := net.InterfaceByIndex(ifindex)
if err != nil {
tun.errors <- err
return
}
// Up / Down event
up := (iface.Flags & net.FlagUp) != 0
if up != statusUp && up {
tun.events <- TUNEventUp
}
if up != statusUp && !up {
tun.events <- TUNEventDown
}
statusUp = up
// MTU changes
if iface.MTU != statusMTU {
tun.events <- TUNEventMTUUpdate
}
statusMTU = iface.MTU
}
}
func CreateTUN(name string, mtu int) (TUNDevice, error) {
ifIndex := -1 ifIndex := -1
if name != "utun" { if name != "utun" {
fmt.Sscanf(name, "utun%d", &ifIndex) _, err := fmt.Sscanf(name, "utun%d", &ifIndex)
if ifIndex < 0 { if err != nil || ifIndex < 0 {
return nil, fmt.Errorf("Interface name must be utun[0-9]*") return nil, fmt.Errorf("Interface name must be utun[0-9]*")
} }
} }
@@ -103,29 +156,39 @@ func CreateTUN(name string) (TUNDevice, error) {
return nil, fmt.Errorf("SYS_CONNECT: %v", errno) return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
} }
tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), "")) tun, err := CreateTUNFromFile(os.NewFile(uintptr(fd), ""), mtu)
if err == nil && name == "utun" { if err == nil && name == "utun" {
fname := os.Getenv("WG_DARWIN_UTUN_NAME_FILE") fname := os.Getenv("WG_TUN_NAME_FILE")
if fname != "" { if fname != "" {
ioutil.WriteFile(fname, []byte(tun.(*NativeTun).name+"\n"), 0400) ioutil.WriteFile(fname, []byte(tun.(*nativeTun).name+"\n"), 0400)
} }
} }
return tun, err return tun, err
} }
func CreateTUNFromFile(file *os.File) (TUNDevice, error) { func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
tun := &NativeTun{ tun := &nativeTun{
fd: file, fd: file,
mtu: 1500, events: make(chan TUNEvent, 10),
events: make(chan TUNEvent, 10), errors: make(chan error, 1),
errors: make(chan error, 1),
statusListenersShutdown: make(chan struct{}),
} }
_, err := tun.Name() name, err := tun.Name()
if err != nil {
tun.fd.Close()
return nil, err
}
tunIfindex, err := func() (int, error) {
iface, err := net.InterfaceByName(name)
if err != nil {
return -1, err
}
return iface.Index, nil
}()
if err != nil { if err != nil {
tun.fd.Close() tun.fd.Close()
return nil, err return nil, err
@@ -137,46 +200,15 @@ func CreateTUNFromFile(file *os.File) (TUNDevice, error) {
return nil, err return nil, err
} }
// TODO: Fix this very naive implementation tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
go func(tun *NativeTun) { if err != nil {
var ( tun.fd.Close()
statusUp bool return nil, err
statusMTU int }
)
for { go tun.routineRouteListener(tunIfindex)
intr, err := net.InterfaceByName(tun.name)
if err != nil {
tun.errors <- err
return
}
// Up / Down event err = tun.setMTU(mtu)
up := (intr.Flags & net.FlagUp) != 0
if up != statusUp && up {
tun.events <- TUNEventUp
}
if up != statusUp && !up {
tun.events <- TUNEventDown
}
statusUp = up
// MTU changes
if intr.MTU != statusMTU {
tun.events <- TUNEventMTUUpdate
}
statusMTU = intr.MTU
select {
case <-time.After(time.Second / 10):
case <-tun.statusListenersShutdown:
return
}
}
}(tun)
// set default MTU
err = tun.setMTU(DefaultMTU)
if err != nil { if err != nil {
tun.Close() tun.Close()
return nil, err return nil, err
@@ -185,7 +217,7 @@ func CreateTUNFromFile(file *os.File) (TUNDevice, error) {
return tun, nil return tun, nil
} }
func (tun *NativeTun) Name() (string, error) { func (tun *nativeTun) Name() (string, error) {
var ifName struct { var ifName struct {
name [16]byte name [16]byte
@@ -208,15 +240,15 @@ func (tun *NativeTun) Name() (string, error) {
return tun.name, nil return tun.name, nil
} }
func (tun *NativeTun) File() *os.File { func (tun *nativeTun) File() *os.File {
return tun.fd return tun.fd
} }
func (tun *NativeTun) Events() chan TUNEvent { func (tun *nativeTun) Events() chan TUNEvent {
return tun.events return tun.events
} }
func (tun *NativeTun) doRead(buff []byte, offset int) (int, error) { func (tun *nativeTun) doRead(buff []byte, offset int) (int, error) {
select { select {
case err := <-tun.errors: case err := <-tun.errors:
return 0, err return 0, err
@@ -230,10 +262,10 @@ func (tun *NativeTun) doRead(buff []byte, offset int) (int, error) {
} }
} }
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *nativeTun) Read(buff []byte, offset int) (int, error) {
for { for {
n, err := tun.doRead(buff, offset) n, err := tun.doRead(buff, offset)
if err == nil || !rwcancel.ErrorIsEAGAIN(err) { if err == nil || !rwcancel.RetryAfterError(err) {
return n, err return n, err
} }
if !tun.rwcancel.ReadyRead() { if !tun.rwcancel.ReadyRead() {
@@ -242,7 +274,7 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
} }
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *nativeTun) Write(buff []byte, offset int) (int, error) {
// reserve space for header // reserve space for header
@@ -265,18 +297,27 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
return tun.fd.Write(buff) return tun.fd.Write(buff)
} }
func (tun *NativeTun) Close() error { func (tun *nativeTun) Close() error {
close(tun.statusListenersShutdown) var err3 error
err1 := tun.rwcancel.Cancel() err1 := tun.rwcancel.Cancel()
err2 := tun.fd.Close() err2 := tun.fd.Close()
close(tun.events) if tun.routeSocket != -1 {
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
err3 = unix.Close(tun.routeSocket)
tun.routeSocket = -1
} else if tun.events != nil {
close(tun.events)
}
if err1 != nil { if err1 != nil {
return err1 return err1
} }
return err2 if err2 != nil {
return err2
}
return err3
} }
func (tun *NativeTun) setMTU(n int) error { func (tun *nativeTun) setMTU(n int) error {
// open datagram socket // open datagram socket
@@ -298,7 +339,7 @@ func (tun *NativeTun) setMTU(n int) error {
var ifr [32]byte var ifr [32]byte
copy(ifr[:], tun.name) copy(ifr[:], tun.name)
binary.LittleEndian.PutUint32(ifr[16:20], uint32(n)) *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(
unix.SYS_IOCTL, unix.SYS_IOCTL,
uintptr(fd), uintptr(fd),
@@ -307,13 +348,13 @@ func (tun *NativeTun) setMTU(n int) error {
) )
if errno != 0 { if errno != 0 {
return fmt.Errorf("Failed to set MTU on %s", tun.name) return fmt.Errorf("failed to set MTU on %s", tun.name)
} }
return nil return nil
} }
func (tun *NativeTun) MTU() (int, error) { func (tun *nativeTun) MTU() (int, error) {
// open datagram socket // open datagram socket
@@ -340,14 +381,8 @@ func (tun *NativeTun) MTU() (int, error) {
uintptr(unsafe.Pointer(&ifr[0])), uintptr(unsafe.Pointer(&ifr[0])),
) )
if errno != 0 { if errno != 0 {
return 0, fmt.Errorf("Failed to get MTU on %s", tun.name) return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
} }
// convert result to signed 32-bit int return int(*(*int32)(unsafe.Pointer(&ifr[16]))), nil
val := binary.LittleEndian.Uint32(ifr[16:20])
if val >= (1 << 31) {
return int(val-(1<<31)) - (1 << 31), nil
}
return int(val), nil
} }

510
tun/tun_freebsd.go Normal file
View File

@@ -0,0 +1,510 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"bytes"
"errors"
"fmt"
"git.zx2c4.com/wireguard-go/rwcancel"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"net"
"os"
"syscall"
"unsafe"
)
// _TUNSIFHEAD, value derived from sys/net/{if_tun,ioccom}.h
// const _TUNSIFHEAD = ((0x80000000) | (((4) & ((1 << 13) - 1) ) << 16) | (uint32(byte('t')) << 8) | (96))
const _TUNSIFHEAD = 0x80047460
const _TUNSIFMODE = 0x8004745e
const _TUNSIFPID = 0x2000745f
// Iface status string max len
const _IFSTATMAX = 800
const SIZEOF_UINTPTR = 4 << (^uintptr(0) >> 32 & 1)
// structure for iface requests with a pointer
type ifreq_ptr struct {
Name [unix.IFNAMSIZ]byte
Data uintptr
Pad0 [24 - SIZEOF_UINTPTR]byte
}
// Structure for iface mtu get/set ioctls
type ifreq_mtu struct {
Name [unix.IFNAMSIZ]byte
MTU uint32
Pad0 [12]byte
}
// Structure for interface status request ioctl
type ifstat struct {
IfsName [unix.IFNAMSIZ]byte
Ascii [_IFSTATMAX]byte
}
type nativeTun struct {
name string
fd *os.File
rwcancel *rwcancel.RWCancel
events chan TUNEvent
errors chan error
routeSocket int
}
func (tun *nativeTun) routineRouteListener(tunIfindex int) {
var (
statusUp bool
statusMTU int
)
defer close(tun.events)
data := make([]byte, os.Getpagesize())
for {
retry:
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
goto retry
}
tun.errors <- err
return
}
if n < 14 {
continue
}
if data[3 /* type */] != unix.RTM_IFINFO {
continue
}
ifindex := int(*(*uint16)(unsafe.Pointer(&data[12 /* ifindex */])))
if ifindex != tunIfindex {
continue
}
iface, err := net.InterfaceByIndex(ifindex)
if err != nil {
tun.errors <- err
return
}
// Up / Down event
up := (iface.Flags & net.FlagUp) != 0
if up != statusUp && up {
tun.events <- TUNEventUp
}
if up != statusUp && !up {
tun.events <- TUNEventDown
}
statusUp = up
// MTU changes
if iface.MTU != statusMTU {
tun.events <- TUNEventMTUUpdate
}
statusMTU = iface.MTU
}
}
func tunName(fd uintptr) (string, error) {
//Terrible hack to make up for freebsd not having a TUNGIFNAME
//First, make sure the tun pid matches this proc's pid
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(_TUNSIFPID),
uintptr(0),
)
if errno != 0 {
return "", fmt.Errorf("failed to set tun device PID: %s", errno.Error())
}
// Open iface control socket
confd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return "", err
}
defer unix.Close(confd)
procPid := os.Getpid()
//Try to find interface with matching PID
for i := 1; ; i++ {
iface, _ := net.InterfaceByIndex(i)
if err != nil || iface == nil {
break
}
// Structs for getting data in and out of SIOCGIFSTATUS ioctl
var ifstatus ifstat
copy(ifstatus.IfsName[:], iface.Name)
// Make the syscall to get the status string
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(confd),
uintptr(unix.SIOCGIFSTATUS),
uintptr(unsafe.Pointer(&ifstatus)),
)
if errno != 0 {
continue
}
nullStr := ifstatus.Ascii[:]
i := bytes.IndexByte(nullStr, 0)
if i < 1 {
continue
}
statStr := string(nullStr[:i])
var pidNum int = 0
// Finally get the owning PID
// Format string taken from sys/net/if_tun.c
_, err := fmt.Sscanf(statStr, "\tOpened by PID %d\n", &pidNum)
if err != nil {
continue
}
if pidNum == procPid {
return iface.Name, nil
}
}
return "", nil
}
// Destroy a named system interface
func tunDestroy(name string) error {
// open control socket
var fd int
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return err
}
defer unix.Close(fd)
// do ioctl call
var ifr [32]byte
copy(ifr[:], name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCIFDESTROY),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
return fmt.Errorf("failed to destroy interface %s: %s", name, errno.Error())
}
return nil
}
func CreateTUN(name string, mtu int) (TUNDevice, error) {
if len(name) > unix.IFNAMSIZ-1 {
return nil, errors.New("interface name too long")
}
// See if interface already exists
iface, _ := net.InterfaceByName(name)
if iface != nil {
return nil, fmt.Errorf("interface %s already exists", name)
}
tunfile, err := os.OpenFile("/dev/tun", unix.O_RDWR, 0)
if err != nil {
return nil, err
}
tunfd := tunfile.Fd()
assignedName, err := tunName(tunfd)
if err != nil {
tunfile.Close()
return nil, err
}
// Enable ifhead mode, otherwise tun will complain if it gets a non-AF_INET packet
ifheadmode := 1
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(tunfd),
uintptr(_TUNSIFHEAD),
uintptr(unsafe.Pointer(&ifheadmode)),
)
if errno != 0 {
return nil, fmt.Errorf("error %s", errno.Error())
}
// Rename tun interface
// Open control socket
confd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return nil, err
}
defer unix.Close(confd)
// set up struct for iface rename
var newnp [unix.IFNAMSIZ]byte
copy(newnp[:], name)
var ifr ifreq_ptr
copy(ifr.Name[:], assignedName)
ifr.Data = uintptr(unsafe.Pointer(&newnp[0]))
//do actual ioctl to rename iface
_, _, errno = unix.Syscall(
unix.SYS_IOCTL,
uintptr(confd),
uintptr(unix.SIOCSIFNAME),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 {
tunfile.Close()
tunDestroy(name)
return nil, fmt.Errorf("failed to rename %s to %s: %s", assignedName, name, errno.Error())
}
return CreateTUNFromFile(tunfile, mtu)
}
func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
tun := &nativeTun{
fd: file,
events: make(chan TUNEvent, 10),
errors: make(chan error, 1),
}
name, err := tun.Name()
if err != nil {
tun.fd.Close()
return nil, err
}
tunIfindex, err := func() (int, error) {
iface, err := net.InterfaceByName(name)
if err != nil {
return -1, err
}
return iface.Index, nil
}()
if err != nil {
tun.fd.Close()
return nil, err
}
tun.rwcancel, err = rwcancel.NewRWCancel(int(file.Fd()))
if err != nil {
tun.fd.Close()
return nil, err
}
tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
tun.fd.Close()
return nil, err
}
go tun.routineRouteListener(tunIfindex)
err = tun.setMTU(mtu)
if err != nil {
tun.Close()
return nil, err
}
return tun, nil
}
func (tun *nativeTun) Name() (string, error) {
name, err := tunName(tun.fd.Fd())
if err != nil {
return "", err
}
tun.name = name
return name, nil
}
func (tun *nativeTun) File() *os.File {
return tun.fd
}
func (tun *nativeTun) Events() chan TUNEvent {
return tun.events
}
func (tun *nativeTun) doRead(buff []byte, offset int) (int, error) {
select {
case err := <-tun.errors:
return 0, err
default:
buff := buff[offset-4:]
n, err := tun.fd.Read(buff[:])
if n < 4 {
return 0, err
}
return n - 4, err
}
}
func (tun *nativeTun) Read(buff []byte, offset int) (int, error) {
for {
n, err := tun.doRead(buff, offset)
if err == nil || !rwcancel.RetryAfterError(err) {
return n, err
}
if !tun.rwcancel.ReadyRead() {
return 0, errors.New("tun device closed")
}
}
}
func (tun *nativeTun) Write(buff []byte, offset int) (int, error) {
// reserve space for header
buff = buff[offset-4:]
// add packet information header
buff[0] = 0x00
buff[1] = 0x00
buff[2] = 0x00
if buff[4]>>4 == ipv6.Version {
buff[3] = unix.AF_INET6
} else {
buff[3] = unix.AF_INET
}
// write
return tun.fd.Write(buff)
}
func (tun *nativeTun) Close() error {
var err4 error
err1 := tun.rwcancel.Cancel()
err2 := tun.fd.Close()
err3 := tunDestroy(tun.name)
if tun.routeSocket != -1 {
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
err4 = unix.Close(tun.routeSocket)
tun.routeSocket = -1
} else if tun.events != nil {
close(tun.events)
}
if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
if err3 != nil {
return err3
}
return err4
}
func (tun *nativeTun) setMTU(n int) error {
// open datagram socket
var fd int
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return err
}
defer unix.Close(fd)
// do ioctl call
var ifr ifreq_mtu
copy(ifr.Name[:], tun.name)
ifr.MTU = uint32(n)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCSIFMTU),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 {
return fmt.Errorf("failed to set MTU on %s", tun.name)
}
return nil
}
func (tun *nativeTun) MTU() (int, error) {
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return 0, err
}
defer unix.Close(fd)
// do ioctl call
var ifr ifreq_mtu
copy(ifr.Name[:], tun.name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCGIFMTU),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 {
return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
}
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}

View File

@@ -1,27 +1,24 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
/* Copyright 2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. */ package tun
package main
/* Implementation of the TUN device interface for linux /* Implementation of the TUN device interface for linux
*/ */
import ( import (
"./rwcancel"
"bytes" "bytes"
"encoding/binary"
"errors" "errors"
"fmt" "fmt"
"git.zx2c4.com/wireguard-go/rwcancel"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"net" "net"
"os" "os"
"strconv" "strconv"
"sync"
"time" "time"
"unsafe" "unsafe"
) )
@@ -31,25 +28,26 @@ const (
ifReqSize = unix.IFNAMSIZ + 64 ifReqSize = unix.IFNAMSIZ + 64
) )
type NativeTun struct { type nativeTun struct {
fd *os.File fd *os.File
fdCancel *rwcancel.RWCancel fdCancel *rwcancel.RWCancel
index int32 // if index index int32 // if index
name string // name of interface name string // name of interface
errors chan error // async error handling errors chan error // async error handling
events chan TUNEvent // device related events events chan TUNEvent // device related events
nopi bool // the device was pased IFF_NO_PI nopi bool // the device was pased IFF_NO_PI
netlinkSock int netlinkSock int
netlinkCancel *rwcancel.RWCancel netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{} statusListenersShutdown chan struct{}
} }
func (tun *NativeTun) File() *os.File { func (tun *nativeTun) File() *os.File {
return tun.fd return tun.fd
} }
func (tun *NativeTun) RoutineHackListener() { func (tun *nativeTun) routineHackListener() {
defer tun.hackListenerClosed.Unlock()
/* This is needed for the detection to work across network namespaces /* This is needed for the detection to work across network namespaces
* If you are reading this and know a better method, please get in touch. * If you are reading this and know a better method, please get in touch.
*/ */
@@ -65,7 +63,7 @@ func (tun *NativeTun) RoutineHackListener() {
return return
} }
select { select {
case <-time.After(time.Second / 10): case <-time.After(time.Second):
case <-tun.statusListenersShutdown: case <-tun.statusListenersShutdown:
return return
} }
@@ -88,8 +86,12 @@ func createNetlinkSocket() (int, error) {
return sock, nil return sock, nil
} }
func (tun *NativeTun) RoutineNetlinkListener() { func (tun *nativeTun) routineNetlinkListener() {
defer unix.Close(tun.netlinkSock) defer func() {
unix.Close(tun.netlinkSock)
tun.hackListenerClosed.Lock()
close(tun.events)
}()
for msg := make([]byte, 1<<16); ; { for msg := make([]byte, 1<<16); ; {
@@ -97,7 +99,7 @@ func (tun *NativeTun) RoutineNetlinkListener() {
var msgn int var msgn int
for { for {
msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0) msgn, _, _, _, err = unix.Recvmsg(tun.netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.ErrorIsEAGAIN(err) { if err == nil || !rwcancel.RetryAfterError(err) {
break break
} }
if !tun.netlinkCancel.ReadyRead() { if !tun.netlinkCancel.ReadyRead() {
@@ -154,7 +156,7 @@ func (tun *NativeTun) RoutineNetlinkListener() {
} }
} }
func (tun *NativeTun) isUp() (bool, error) { func (tun *nativeTun) isUp() (bool, error) {
inter, err := net.InterfaceByName(tun.name) inter, err := net.InterfaceByName(tun.name)
return inter.Flags&net.FlagUp != 0, err return inter.Flags&net.FlagUp != 0, err
} }
@@ -188,11 +190,10 @@ func getIFIndex(name string) (int32, error) {
return 0, errno return 0, errno
} }
index := binary.LittleEndian.Uint32(ifr[unix.IFNAMSIZ:]) return *(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])), nil
return toInt32(index), nil
} }
func (tun *NativeTun) setMTU(n int) error { func (tun *nativeTun) setMTU(n int) error {
// open datagram socket // open datagram socket
@@ -212,7 +213,7 @@ func (tun *NativeTun) setMTU(n int) error {
var ifr [ifReqSize]byte var ifr [ifReqSize]byte
copy(ifr[:], tun.name) copy(ifr[:], tun.name)
binary.LittleEndian.PutUint32(ifr[16:20], uint32(n)) *(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(
unix.SYS_IOCTL, unix.SYS_IOCTL,
uintptr(fd), uintptr(fd),
@@ -221,13 +222,13 @@ func (tun *NativeTun) setMTU(n int) error {
) )
if errno != 0 { if errno != 0 {
return errors.New("Failed to set MTU of TUN device") return errors.New("failed to set MTU of TUN device")
} }
return nil return nil
} }
func (tun *NativeTun) MTU() (int, error) { func (tun *nativeTun) MTU() (int, error) {
// open datagram socket // open datagram socket
@@ -254,19 +255,13 @@ func (tun *NativeTun) MTU() (int, error) {
uintptr(unsafe.Pointer(&ifr[0])), uintptr(unsafe.Pointer(&ifr[0])),
) )
if errno != 0 { if errno != 0 {
return 0, errors.New("Failed to get MTU of TUN device: " + strconv.FormatInt(int64(errno), 10)) return 0, errors.New("failed to get MTU of TUN device: " + strconv.FormatInt(int64(errno), 10))
} }
// convert result to signed 32-bit int return int(*(*int32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ]))), nil
val := binary.LittleEndian.Uint32(ifr[16:20])
if val >= (1 << 31) {
return int(toInt32(val)), nil
}
return int(val), nil
} }
func (tun *NativeTun) Name() (string, error) { func (tun *nativeTun) Name() (string, error) {
var ifr [ifReqSize]byte var ifr [ifReqSize]byte
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(
@@ -276,7 +271,7 @@ func (tun *NativeTun) Name() (string, error) {
uintptr(unsafe.Pointer(&ifr[0])), uintptr(unsafe.Pointer(&ifr[0])),
) )
if errno != 0 { if errno != 0 {
return "", errors.New("Failed to get name of TUN device: " + strconv.FormatInt(int64(errno), 10)) return "", errors.New("failed to get name of TUN device: " + strconv.FormatInt(int64(errno), 10))
} }
nullStr := ifr[:] nullStr := ifr[:]
i := bytes.IndexByte(nullStr, 0) i := bytes.IndexByte(nullStr, 0)
@@ -287,7 +282,7 @@ func (tun *NativeTun) Name() (string, error) {
return tun.name, nil return tun.name, nil
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *nativeTun) Write(buff []byte, offset int) (int, error) {
if tun.nopi { if tun.nopi {
buff = buff[offset:] buff = buff[offset:]
@@ -315,7 +310,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
return tun.fd.Write(buff) return tun.fd.Write(buff)
} }
func (tun *NativeTun) doRead(buff []byte, offset int) (int, error) { func (tun *nativeTun) doRead(buff []byte, offset int) (int, error) {
select { select {
case err := <-tun.errors: case err := <-tun.errors:
return 0, err return 0, err
@@ -333,10 +328,10 @@ func (tun *NativeTun) doRead(buff []byte, offset int) (int, error) {
} }
} }
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *nativeTun) Read(buff []byte, offset int) (int, error) {
for { for {
n, err := tun.doRead(buff, offset) n, err := tun.doRead(buff, offset)
if err == nil || !rwcancel.ErrorIsEAGAIN(err) { if err == nil || !rwcancel.RetryAfterError(err) {
return n, err return n, err
} }
if !tun.fdCancel.ReadyRead() { if !tun.fdCancel.ReadyRead() {
@@ -345,19 +340,22 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
} }
} }
func (tun *NativeTun) Events() chan TUNEvent { func (tun *nativeTun) Events() chan TUNEvent {
return tun.events return tun.events
} }
func (tun *NativeTun) Close() error { func (tun *nativeTun) Close() error {
var err1 error var err1 error
close(tun.statusListenersShutdown) if tun.statusListenersShutdown != nil {
if tun.netlinkCancel != nil { close(tun.statusListenersShutdown)
err1 = tun.netlinkCancel.Cancel() if tun.netlinkCancel != nil {
err1 = tun.netlinkCancel.Cancel()
}
} else if tun.events != nil {
close(tun.events)
} }
err2 := tun.fd.Close() err2 := tun.fd.Close()
err3 := tun.fdCancel.Cancel() err3 := tun.fdCancel.Cancel()
close(tun.events)
if err1 != nil { if err1 != nil {
return err1 return err1
@@ -368,7 +366,7 @@ func (tun *NativeTun) Close() error {
return err3 return err3
} }
func CreateTUN(name string) (TUNDevice, error) { func CreateTUN(name string, mtu int) (TUNDevice, error) {
// open clone device // open clone device
@@ -398,7 +396,7 @@ func CreateTUN(name string) (TUNDevice, error) {
return nil, errors.New("interface name too long") return nil, errors.New("interface name too long")
} }
copy(ifr[:], nameBytes) copy(ifr[:], nameBytes)
binary.LittleEndian.PutUint16(ifr[16:], flags) *(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(
unix.SYS_IOCTL, unix.SYS_IOCTL,
@@ -410,12 +408,12 @@ func CreateTUN(name string) (TUNDevice, error) {
return nil, errno return nil, errno
} }
return CreateTUNFromFile(fd) return CreateTUNFromFile(fd, mtu)
} }
func CreateTUNFromFile(fd *os.File) (TUNDevice, error) { func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
tun := &NativeTun{ tun := &nativeTun{
fd: fd, fd: file,
events: make(chan TUNEvent, 5), events: make(chan TUNEvent, 5),
errors: make(chan error, 5), errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}), statusListenersShutdown: make(chan struct{}),
@@ -423,7 +421,7 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
} }
var err error var err error
tun.fdCancel, err = rwcancel.NewRWCancel(int(fd.Fd())) tun.fdCancel, err = rwcancel.NewRWCancel(int(file.Fd()))
if err != nil { if err != nil {
tun.fd.Close() tun.fd.Close()
return nil, err return nil, err
@@ -453,12 +451,11 @@ func CreateTUNFromFile(fd *os.File) (TUNDevice, error) {
return nil, err return nil, err
} }
go tun.RoutineNetlinkListener() tun.hackListenerClosed.Lock()
go tun.RoutineHackListener() // cross namespace go tun.routineNetlinkListener()
go tun.routineHackListener() // cross namespace
// set default MTU err = tun.setMTU(mtu)
err = tun.setMTU(DefaultMTU)
if err != nil { if err != nil {
tun.Close() tun.Close()
return nil, err return nil, err

348
tun/tun_openbsd.go Normal file
View File

@@ -0,0 +1,348 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"errors"
"fmt"
"git.zx2c4.com/wireguard-go/rwcancel"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"io/ioutil"
"net"
"os"
"syscall"
"unsafe"
)
// Structure for iface mtu get/set ioctls
type ifreq_mtu struct {
Name [unix.IFNAMSIZ]byte
MTU uint32
Pad0 [12]byte
}
const _TUNSIFMODE = 0x8004745d
type nativeTun struct {
name string
fd *os.File
rwcancel *rwcancel.RWCancel
events chan TUNEvent
errors chan error
routeSocket int
}
func (tun *nativeTun) routineRouteListener(tunIfindex int) {
var (
statusUp bool
statusMTU int
)
defer close(tun.events)
data := make([]byte, os.Getpagesize())
for {
retry:
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
goto retry
}
tun.errors <- err
return
}
if n < 8 {
continue
}
if data[3 /* type */] != unix.RTM_IFINFO {
continue
}
ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */])))
if ifindex != tunIfindex {
continue
}
iface, err := net.InterfaceByIndex(ifindex)
if err != nil {
tun.errors <- err
return
}
// Up / Down event
up := (iface.Flags & net.FlagUp) != 0
if up != statusUp && up {
tun.events <- TUNEventUp
}
if up != statusUp && !up {
tun.events <- TUNEventDown
}
statusUp = up
// MTU changes
if iface.MTU != statusMTU {
tun.events <- TUNEventMTUUpdate
}
statusMTU = iface.MTU
}
}
func errorIsEBUSY(err error) bool {
if pe, ok := err.(*os.PathError); ok {
err = pe.Err
}
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EBUSY {
return true
}
return false
}
func CreateTUN(name string, mtu int) (TUNDevice, error) {
ifIndex := -1
if name != "tun" {
_, err := fmt.Sscanf(name, "tun%d", &ifIndex)
if err != nil || ifIndex < 0 {
return nil, fmt.Errorf("Interface name must be tun[0-9]*")
}
}
var tunfile *os.File
var err error
if ifIndex != -1 {
tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0)
} else {
for ifIndex = 0; ifIndex < 256; ifIndex += 1 {
tunfile, err = os.OpenFile(fmt.Sprintf("/dev/tun%d", ifIndex), unix.O_RDWR, 0)
if err == nil || !errorIsEBUSY(err) {
break
}
}
}
if err != nil {
return nil, err
}
tun, err := CreateTUNFromFile(tunfile, mtu)
if err == nil && name == "tun" {
fname := os.Getenv("WG_TUN_NAME_FILE")
if fname != "" {
ioutil.WriteFile(fname, []byte(tun.(*nativeTun).name+"\n"), 0400)
}
}
return tun, err
}
func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
tun := &nativeTun{
fd: file,
events: make(chan TUNEvent, 10),
errors: make(chan error, 1),
}
name, err := tun.Name()
if err != nil {
tun.fd.Close()
return nil, err
}
tunIfindex, err := func() (int, error) {
iface, err := net.InterfaceByName(name)
if err != nil {
return -1, err
}
return iface.Index, nil
}()
if err != nil {
tun.fd.Close()
return nil, err
}
tun.rwcancel, err = rwcancel.NewRWCancel(int(file.Fd()))
if err != nil {
tun.fd.Close()
return nil, err
}
tun.routeSocket, err = unix.Socket(unix.AF_ROUTE, unix.SOCK_RAW, unix.AF_UNSPEC)
if err != nil {
tun.fd.Close()
return nil, err
}
go tun.routineRouteListener(tunIfindex)
err = tun.setMTU(mtu)
if err != nil {
tun.Close()
return nil, err
}
return tun, nil
}
func (tun *nativeTun) Name() (string, error) {
gostat, err := tun.fd.Stat()
if err != nil {
tun.name = ""
return "", err
}
stat := gostat.Sys().(*syscall.Stat_t)
tun.name = fmt.Sprintf("tun%d", stat.Rdev%256)
return tun.name, nil
}
func (tun *nativeTun) File() *os.File {
return tun.fd
}
func (tun *nativeTun) Events() chan TUNEvent {
return tun.events
}
func (tun *nativeTun) doRead(buff []byte, offset int) (int, error) {
select {
case err := <-tun.errors:
return 0, err
default:
buff := buff[offset-4:]
n, err := tun.fd.Read(buff[:])
if n < 4 {
return 0, err
}
return n - 4, err
}
}
func (tun *nativeTun) Read(buff []byte, offset int) (int, error) {
for {
n, err := tun.doRead(buff, offset)
if err == nil || !rwcancel.RetryAfterError(err) {
return n, err
}
if !tun.rwcancel.ReadyRead() {
return 0, errors.New("tun device closed")
}
}
}
func (tun *nativeTun) Write(buff []byte, offset int) (int, error) {
// reserve space for header
buff = buff[offset-4:]
// add packet information header
buff[0] = 0x00
buff[1] = 0x00
buff[2] = 0x00
if buff[4]>>4 == ipv6.Version {
buff[3] = unix.AF_INET6
} else {
buff[3] = unix.AF_INET
}
// write
return tun.fd.Write(buff)
}
func (tun *nativeTun) Close() error {
var err3 error
err1 := tun.rwcancel.Cancel()
err2 := tun.fd.Close()
if tun.routeSocket != -1 {
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
err3 = unix.Close(tun.routeSocket)
tun.routeSocket = -1
} else if tun.events != nil {
close(tun.events)
}
if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
return err3
}
func (tun *nativeTun) setMTU(n int) error {
// open datagram socket
var fd int
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return err
}
defer unix.Close(fd)
// do ioctl call
var ifr ifreq_mtu
copy(ifr.Name[:], tun.name)
ifr.MTU = uint32(n)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCSIFMTU),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 {
return fmt.Errorf("failed to set MTU on %s", tun.name)
}
return nil
}
func (tun *nativeTun) MTU() (int, error) {
// open datagram socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return 0, err
}
defer unix.Close(fd)
// do ioctl call
var ifr ifreq_mtu
copy(ifr.Name[:], tun.name)
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCGIFMTU),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 {
return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
}
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
}

View File

@@ -1,483 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/
package main
import (
"encoding/binary"
"errors"
"fmt"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"net"
"sync"
"syscall"
"time"
"unsafe"
)
/* Relies on the OpenVPN TAP-Windows driver (NDIS 6 version)
*
* https://github.com/OpenVPN/tap-windows
*/
type NativeTUN struct {
fd windows.Handle
rl sync.Mutex
wl sync.Mutex
ro *windows.Overlapped
wo *windows.Overlapped
events chan TUNEvent
name string
}
const (
METHOD_BUFFERED = 0
ComponentID = "tap0901" // tap0801
)
func ctl_code(device_type, function, method, access uint32) uint32 {
return (device_type << 16) | (access << 14) | (function << 2) | method
}
func TAP_CONTROL_CODE(request, method uint32) uint32 {
return ctl_code(file_device_unknown, request, method, 0)
}
var (
errIfceNameNotFound = errors.New("Failed to find the name of interface")
TAP_IOCTL_GET_MAC = TAP_CONTROL_CODE(1, METHOD_BUFFERED)
TAP_IOCTL_GET_VERSION = TAP_CONTROL_CODE(2, METHOD_BUFFERED)
TAP_IOCTL_GET_MTU = TAP_CONTROL_CODE(3, METHOD_BUFFERED)
TAP_IOCTL_GET_INFO = TAP_CONTROL_CODE(4, METHOD_BUFFERED)
TAP_IOCTL_CONFIG_POINT_TO_POINT = TAP_CONTROL_CODE(5, METHOD_BUFFERED)
TAP_IOCTL_SET_MEDIA_STATUS = TAP_CONTROL_CODE(6, METHOD_BUFFERED)
TAP_IOCTL_CONFIG_DHCP_MASQ = TAP_CONTROL_CODE(7, METHOD_BUFFERED)
TAP_IOCTL_GET_LOG_LINE = TAP_CONTROL_CODE(8, METHOD_BUFFERED)
TAP_IOCTL_CONFIG_DHCP_SET_OPT = TAP_CONTROL_CODE(9, METHOD_BUFFERED)
TAP_IOCTL_CONFIG_TUN = TAP_CONTROL_CODE(10, METHOD_BUFFERED)
file_device_unknown = uint32(0x00000022)
nCreateEvent,
nResetEvent,
nGetOverlappedResult uintptr
)
func init() {
k32, err := windows.LoadLibrary("kernel32.dll")
if err != nil {
panic("LoadLibrary " + err.Error())
}
defer windows.FreeLibrary(k32)
nCreateEvent = getProcAddr(k32, "CreateEventW")
nResetEvent = getProcAddr(k32, "ResetEvent")
nGetOverlappedResult = getProcAddr(k32, "GetOverlappedResult")
}
/* implementation of the read/write/closer interface */
func getProcAddr(lib windows.Handle, name string) uintptr {
addr, err := windows.GetProcAddress(lib, name)
if err != nil {
panic(name + " " + err.Error())
}
return addr
}
func resetEvent(h windows.Handle) error {
r, _, err := syscall.Syscall(nResetEvent, 1, uintptr(h), 0, 0)
if r == 0 {
return err
}
return nil
}
func getOverlappedResult(h windows.Handle, overlapped *windows.Overlapped) (int, error) {
var n int
r, _, err := syscall.Syscall6(
nGetOverlappedResult,
4,
uintptr(h),
uintptr(unsafe.Pointer(overlapped)),
uintptr(unsafe.Pointer(&n)), 1, 0, 0)
if r == 0 {
return n, err
}
return n, nil
}
func newOverlapped() (*windows.Overlapped, error) {
var overlapped windows.Overlapped
r, _, err := syscall.Syscall6(nCreateEvent, 4, 0, 1, 0, 0, 0, 0)
if r == 0 {
return nil, err
}
overlapped.HEvent = windows.Handle(r)
return &overlapped, nil
}
func (f *NativeTUN) Events() chan TUNEvent {
return f.events
}
func (f *NativeTUN) Close() error {
close(f.events)
err := windows.Close(f.fd)
return err
}
func (f *NativeTUN) Write(b []byte) (int, error) {
f.wl.Lock()
defer f.wl.Unlock()
if err := resetEvent(f.wo.HEvent); err != nil {
return 0, err
}
var n uint32
err := windows.WriteFile(f.fd, b, &n, f.wo)
if err != nil && err != windows.ERROR_IO_PENDING {
return int(n), err
}
return getOverlappedResult(f.fd, f.wo)
}
func (f *NativeTUN) Read(b []byte) (int, error) {
f.rl.Lock()
defer f.rl.Unlock()
if err := resetEvent(f.ro.HEvent); err != nil {
return 0, err
}
var done uint32
err := windows.ReadFile(f.fd, b, &done, f.ro)
if err != nil && err != windows.ERROR_IO_PENDING {
return int(done), err
}
return getOverlappedResult(f.fd, f.ro)
}
func getdeviceid(
targetComponentId string,
targetDeviceName string,
) (deviceid string, err error) {
getName := func(instanceId string) (string, error) {
path := fmt.Sprintf(
`SYSTEM\CurrentControlSet\Control\Network\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s\Connection`,
instanceId,
)
key, err := registry.OpenKey(
registry.LOCAL_MACHINE,
path,
registry.READ,
)
if err != nil {
return "", err
}
defer key.Close()
val, _, err := key.GetStringValue("Name")
key.Close()
return val, err
}
getInstanceId := func(keyName string) (string, string, error) {
path := fmt.Sprintf(
`SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}\%s`,
keyName,
)
key, err := registry.OpenKey(
registry.LOCAL_MACHINE,
path,
registry.READ,
)
if err != nil {
return "", "", err
}
defer key.Close()
componentId, _, err := key.GetStringValue("ComponentId")
if err != nil {
return "", "", err
}
instanceId, _, err := key.GetStringValue("NetCfgInstanceId")
return componentId, instanceId, err
}
// find list of all network devices
k, err := registry.OpenKey(
registry.LOCAL_MACHINE,
`SYSTEM\CurrentControlSet\Control\Class\{4D36E972-E325-11CE-BFC1-08002BE10318}`,
registry.READ,
)
if err != nil {
return "", fmt.Errorf("Failed to open the adapter registry, TAP driver may be not installed, %v", err)
}
defer k.Close()
keys, err := k.ReadSubKeyNames(-1)
if err != nil {
return "", err
}
// look for matching component id and name
var componentFound bool
for _, v := range keys {
componentId, instanceId, err := getInstanceId(v)
if err != nil || componentId != targetComponentId {
continue
}
componentFound = true
deviceName, err := getName(instanceId)
if err != nil || deviceName != targetDeviceName {
continue
}
return instanceId, nil
}
// provide a descriptive error message
if componentFound {
return "", fmt.Errorf("Unable to find tun/tap device with name = %s", targetDeviceName)
}
return "", fmt.Errorf(
"Unable to find device in registry with ComponentId = %s, is tap-windows installed?",
targetComponentId,
)
}
// setStatus is used to bring up or bring down the interface
func setStatus(fd windows.Handle, status bool) error {
var code [4]byte
if status {
binary.LittleEndian.PutUint32(code[:], 1)
}
var bytesReturned uint32
rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE)
return windows.DeviceIoControl(
fd,
TAP_IOCTL_SET_MEDIA_STATUS,
&code[0],
uint32(4),
&rdbbuf[0],
uint32(len(rdbbuf)),
&bytesReturned,
nil,
)
}
/* When operating in TUN mode we must assign an ip address & subnet to the device.
*
*/
func setTUN(fd windows.Handle, network string) error {
var bytesReturned uint32
rdbbuf := make([]byte, windows.MAXIMUM_REPARSE_DATA_BUFFER_SIZE)
localIP, remoteNet, err := net.ParseCIDR(network)
if err != nil {
return fmt.Errorf("Failed to parse network CIDR in config, %v", err)
}
if localIP.To4() == nil {
return fmt.Errorf("Provided network(%s) is not a valid IPv4 address", network)
}
var param [12]byte
copy(param[0:4], localIP.To4())
copy(param[4:8], remoteNet.IP.To4())
copy(param[8:12], remoteNet.Mask)
return windows.DeviceIoControl(
fd,
TAP_IOCTL_CONFIG_TUN,
&param[0],
uint32(12),
&rdbbuf[0],
uint32(len(rdbbuf)),
&bytesReturned,
nil,
)
}
func (tun *NativeTUN) MTU() (int, error) {
var mtu [4]byte
var bytesReturned uint32
err := windows.DeviceIoControl(
tun.fd,
TAP_IOCTL_GET_MTU,
&mtu[0],
uint32(len(mtu)),
&mtu[0],
uint32(len(mtu)),
&bytesReturned,
nil,
)
val := binary.LittleEndian.Uint32(mtu[:])
return int(val), err
}
func (tun *NativeTUN) Name() string {
return tun.name
}
func CreateTUN(name string) (TUNDevice, error) {
// find the device in registry.
deviceid, err := getdeviceid(ComponentID, name)
if err != nil {
return nil, err
}
path := "\\\\.\\Global\\" + deviceid + ".tap"
pathp, err := windows.UTF16PtrFromString(path)
if err != nil {
return nil, err
}
// create TUN device
handle, err := windows.CreateFile(
pathp,
windows.GENERIC_READ|windows.GENERIC_WRITE,
0,
nil,
windows.OPEN_EXISTING,
windows.FILE_ATTRIBUTE_SYSTEM|windows.FILE_FLAG_OVERLAPPED,
0,
)
if err != nil {
return nil, err
}
ro, err := newOverlapped()
if err != nil {
windows.Close(handle)
return nil, err
}
wo, err := newOverlapped()
if err != nil {
windows.Close(handle)
return nil, err
}
tun := &NativeTUN{
fd: handle,
name: name,
ro: ro,
wo: wo,
events: make(chan TUNEvent, 5),
}
// find addresses of interface
// TODO: fix this hack, the question is how
inter, err := net.InterfaceByName(name)
if err != nil {
windows.Close(handle)
return nil, err
}
addrs, err := inter.Addrs()
if err != nil {
windows.Close(handle)
return nil, err
}
var ip net.IP
for _, addr := range addrs {
ip = func() net.IP {
switch v := addr.(type) {
case *net.IPNet:
return v.IP.To4()
case *net.IPAddr:
return v.IP.To4()
}
return nil
}()
if ip != nil {
break
}
}
if ip == nil {
windows.Close(handle)
return nil, errors.New("No IPv4 address found for interface")
}
// bring up device.
if err := setStatus(handle, true); err != nil {
windows.Close(handle)
return nil, err
}
// set tun mode
mask := ip.String() + "/0"
if err := setTUN(handle, mask); err != nil {
windows.Close(handle)
return nil, err
}
// start listener
go func(native *NativeTUN, ifname string) {
// TODO: Fix this very niave implementation
var (
statusUp bool
statusMTU int
)
for ; ; time.Sleep(time.Second) {
intr, err := net.InterfaceByName(name)
if err != nil {
// TODO: handle
return
}
// Up / Down event
up := (intr.Flags & net.FlagUp) != 0
if up != statusUp && up {
native.events <- TUNEventUp
}
if up != statusUp && !up {
native.events <- TUNEventDown
}
statusUp = up
// MTU changes
if intr.MTU != statusMTU {
native.events <- TUNEventMTUUpdate
}
statusMTU = intr.MTU
}
}(tun, name)
return tun, nil
}

49
uapi.go
View File

@@ -1,7 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
@@ -75,6 +74,7 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send("public_key=" + peer.handshake.remoteStatic.ToHex()) send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex()) send("preshared_key=" + peer.handshake.presharedKey.ToHex())
send("protocol_version=1")
if peer.endpoint != nil { if peer.endpoint != nil {
send("endpoint=" + peer.endpoint.DstToString()) send("endpoint=" + peer.endpoint.DstToString())
} }
@@ -85,8 +85,8 @@ func ipcGetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
send(fmt.Sprintf("last_handshake_time_sec=%d", secs)) send(fmt.Sprintf("last_handshake_time_sec=%d", secs))
send(fmt.Sprintf("last_handshake_time_nsec=%d", nano)) send(fmt.Sprintf("last_handshake_time_nsec=%d", nano))
send(fmt.Sprintf("tx_bytes=%d", peer.stats.txBytes)) send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)))
send(fmt.Sprintf("rx_bytes=%d", peer.stats.rxBytes)) send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)))
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval)) send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
for _, ip := range device.allowedips.EntriesForPeer(peer) { for _, ip := range device.allowedips.EntriesForPeer(peer) {
@@ -147,7 +147,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to set private_key:", err) logError.Println("Failed to set private_key:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
logDebug.Println("UAPI: Updating device private key") logDebug.Println("UAPI: Updating private key")
device.SetPrivateKey(sk) device.SetPrivateKey(sk)
case "listen_port": case "listen_port":
@@ -211,7 +211,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
device.RemoveAllPeers() device.RemoveAllPeers()
default: default:
logError.Println("Invalid UAPI key (device configuration):", key) logError.Println("Invalid UAPI device key:", key)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
} }
@@ -226,7 +226,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
var publicKey NoisePublicKey var publicKey NoisePublicKey
err := publicKey.FromHex(value) err := publicKey.FromHex(value)
if err != nil { if err != nil {
logError.Println("Failed to get peer by public_key:", err) logError.Println("Failed to get peer by public key:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
@@ -248,7 +248,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
logError.Println("Failed to create new peer:", err) logError.Println("Failed to create new peer:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
logDebug.Println("UAPI: Created new peer:", peer) logDebug.Println(peer, "- UAPI: Created")
} }
case "remove": case "remove":
@@ -260,7 +260,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
if !dummy { if !dummy {
logDebug.Println("UAPI: Removing peer:", peer) logDebug.Println(peer, "- UAPI: Removing")
device.RemovePeer(peer.handshake.remoteStatic) device.RemovePeer(peer.handshake.remoteStatic)
} }
peer = &Peer{} peer = &Peer{}
@@ -270,14 +270,14 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update PSK // update PSK
logDebug.Println("UAPI: Updating pre-shared key for peer:", peer) logDebug.Println(peer, "- UAPI: Updating preshared key")
peer.handshake.mutex.Lock() peer.handshake.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value) err := peer.handshake.presharedKey.FromHex(value)
peer.handshake.mutex.Unlock() peer.handshake.mutex.Unlock()
if err != nil { if err != nil {
logError.Println("Failed to set preshared_key:", err) logError.Println("Failed to set preshared key:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
@@ -285,7 +285,7 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// set endpoint destination // set endpoint destination
logDebug.Println("UAPI: Updating endpoint for peer:", peer) logDebug.Println(peer, "- UAPI: Updating endpoint")
err := func() error { err := func() error {
peer.mutex.Lock() peer.mutex.Lock()
@@ -307,11 +307,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
// update persistent keepalive interval // update persistent keepalive interval
logDebug.Println("UAPI: Updating persistent_keepalive_interval for peer:", peer) logDebug.Println(peer, "- UAPI: Updating persistent keepalive interva")
secs, err := strconv.ParseUint(value, 10, 16) secs, err := strconv.ParseUint(value, 10, 16)
if err != nil { if err != nil {
logError.Println("Failed to set persistent_keepalive_interval:", err) logError.Println("Failed to set persistent keepalive interval:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
@@ -332,10 +332,10 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "replace_allowed_ips": case "replace_allowed_ips":
logDebug.Println("UAPI: Removing all allowedips for peer:", peer) logDebug.Println(peer, "- UAPI: Removing all allowedips")
if value != "true" { if value != "true" {
logError.Println("Failed to set replace_allowed_ips, invalid value:", value) logError.Println("Failed to replace allowedips, invalid value:", value)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
@@ -347,11 +347,11 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
case "allowed_ip": case "allowed_ip":
logDebug.Println("UAPI: Adding allowed_ip to peer:", peer) logDebug.Println(peer, "- UAPI: Adding allowedip")
_, network, err := net.ParseCIDR(value) _, network, err := net.ParseCIDR(value)
if err != nil { if err != nil {
logError.Println("Failed to set allowed_ip:", err) logError.Println("Failed to set allowed ip:", err)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
@@ -362,8 +362,15 @@ func ipcSetOperation(device *Device, socket *bufio.ReadWriter) *IPCError {
ones, _ := network.Mask.Size() ones, _ := network.Mask.Size()
device.allowedips.Insert(network.IP, uint(ones), peer) device.allowedips.Insert(network.IP, uint(ones), peer)
case "protocol_version":
if value != "1" {
logError.Println("Invalid protocol version:", value)
return &IPCError{Code: ipcErrorInvalid}
}
default: default:
logError.Println("Invalid UAPI key (peer configuration):", key) logError.Println("Invalid UAPI peer key:", key)
return &IPCError{Code: ipcErrorInvalid} return &IPCError{Code: ipcErrorInvalid}
} }
} }
@@ -397,11 +404,11 @@ func ipcHandle(device *Device, socket net.Conn) {
switch op { switch op {
case "set=1\n": case "set=1\n":
device.log.Debug.Println("Config, set operation") device.log.Debug.Println("UAPI: Set operation")
status = ipcSetOperation(device, buffered) status = ipcSetOperation(device, buffered)
case "get=1\n": case "get=1\n":
device.log.Debug.Println("Config, get operation") device.log.Debug.Println("UAPI: Get operation")
status = ipcGetOperation(device, buffered) status = ipcGetOperation(device, buffered)
default: default:

View File

@@ -1,7 +1,8 @@
// +build darwin freebsd openbsd
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
@@ -13,14 +14,16 @@ import (
"net" "net"
"os" "os"
"path" "path"
"unsafe"
) )
var socketDirectory = "/var/run/wireguard"
const ( const (
ipcErrorIO = -int64(unix.EIO) ipcErrorIO = -int64(unix.EIO)
ipcErrorProtocol = -int64(unix.EPROTO) ipcErrorProtocol = -int64(unix.EPROTO)
ipcErrorInvalid = -int64(unix.EINVAL) ipcErrorInvalid = -int64(unix.EINVAL)
ipcErrorPortInUse = -int64(unix.EADDRINUSE) ipcErrorPortInUse = -int64(unix.EADDRINUSE)
socketDirectory = "/var/run/wireguard"
socketName = "%s.sock" socketName = "%s.sock"
) )
@@ -91,7 +94,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
uapi.keventFd, err = unix.Open(socketDirectory, unix.O_EVTONLY, 0) uapi.keventFd, err = unix.Open(socketDirectory, unix.O_RDONLY, 0)
if err != nil { if err != nil {
unix.Close(uapi.kqueueFd) unix.Close(uapi.kqueueFd)
return nil, err return nil, err
@@ -99,11 +102,13 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
go func(l *UAPIListener) { go func(l *UAPIListener) {
event := unix.Kevent_t{ event := unix.Kevent_t{
Ident: uint64(uapi.keventFd),
Filter: unix.EVFILT_VNODE, Filter: unix.EVFILT_VNODE,
Flags: unix.EV_ADD | unix.EV_ENABLE | unix.EV_ONESHOT, Flags: unix.EV_ADD | unix.EV_ENABLE | unix.EV_ONESHOT,
Fflags: unix.NOTE_WRITE, Fflags: unix.NOTE_WRITE,
} }
// Allow this assignment to work with both the 32-bit and 64-bit version
// of the above struct. If you know another way, please submit a patch.
*(*uintptr)(unsafe.Pointer(&event.Ident)) = uintptr(uapi.keventFd)
events := make([]unix.Kevent_t, 1) events := make([]unix.Kevent_t, 1)
n := 1 n := 1
var kerr error var kerr error
@@ -145,7 +150,7 @@ func UAPIOpen(name string) (*os.File, error) {
// check if path exist // check if path exist
err := os.MkdirAll(socketDirectory, 0700) err := os.MkdirAll(socketDirectory, 0755)
if err != nil && !os.IsExist(err) { if err != nil && !os.IsExist(err) {
return nil, err return nil, err
} }
@@ -162,6 +167,7 @@ func UAPIOpen(name string) (*os.File, error) {
return nil, err return nil, err
} }
oldUmask := unix.Umask(0077)
listener, err := func() (*net.UnixListener, error) { listener, err := func() (*net.UnixListener, error) {
// initial connection attempt // initial connection attempt
@@ -186,6 +192,7 @@ func UAPIOpen(name string) (*os.File, error) {
} }
return net.ListenUnix("unix", addr) return net.ListenUnix("unix", addr)
}() }()
unix.Umask(oldUmask)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -1,27 +1,27 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/ */
package main package main
import ( import (
"./rwcancel"
"errors" "errors"
"fmt" "fmt"
"git.zx2c4.com/wireguard-go/rwcancel"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
"net" "net"
"os" "os"
"path" "path"
) )
var socketDirectory = "/var/run/wireguard"
const ( const (
ipcErrorIO = -int64(unix.EIO) ipcErrorIO = -int64(unix.EIO)
ipcErrorProtocol = -int64(unix.EPROTO) ipcErrorProtocol = -int64(unix.EPROTO)
ipcErrorInvalid = -int64(unix.EINVAL) ipcErrorInvalid = -int64(unix.EINVAL)
ipcErrorPortInUse = -int64(unix.EADDRINUSE) ipcErrorPortInUse = -int64(unix.EADDRINUSE)
socketDirectory = "/var/run/wireguard"
socketName = "%s.sock" socketName = "%s.sock"
) )
@@ -147,7 +147,7 @@ func UAPIOpen(name string) (*os.File, error) {
// check if path exist // check if path exist
err := os.MkdirAll(socketDirectory, 0700) err := os.MkdirAll(socketDirectory, 0755)
if err != nil && !os.IsExist(err) { if err != nil && !os.IsExist(err) {
return nil, err return nil, err
} }
@@ -164,6 +164,7 @@ func UAPIOpen(name string) (*os.File, error) {
return nil, err return nil, err
} }
oldUmask := unix.Umask(0077)
listener, err := func() (*net.UnixListener, error) { listener, err := func() (*net.UnixListener, error) {
// initial connection attempt // initial connection attempt
@@ -188,6 +189,7 @@ func UAPIOpen(name string) (*os.File, error) {
} }
return net.ListenUnix("unix", addr) return net.ListenUnix("unix", addr)
}() }()
unix.Umask(oldUmask)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@@ -1,50 +0,0 @@
/* SPDX-License-Identifier: GPL-2.0
*
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved.
* Copyright (C) 2017-2018 Mathias N. Hall-Andersen <mathias@hall-andersen.dk>.
*/
package main
/* UAPI on windows uses a bidirectional named pipe
*/
import (
"fmt"
"github.com/Microsoft/go-winio"
"golang.org/x/sys/windows"
"net"
)
const (
ipcErrorIO = -int64(windows.ERROR_BROKEN_PIPE)
ipcErrorProtocol = -int64(windows.ERROR_INVALID_NAME)
ipcErrorInvalid = -int64(windows.ERROR_INVALID_PARAMETER)
ipcErrorPortInUse = -int64(windows.ERROR_ALREADY_EXISTS)
)
const PipeNameFmt = "\\\\.\\pipe\\wireguard-ipc-%s"
type UAPIListener struct {
listener net.Listener
}
func (uapi *UAPIListener) Accept() (net.Conn, error) {
return nil, nil
}
func (uapi *UAPIListener) Close() error {
return uapi.listener.Close()
}
func (uapi *UAPIListener) Addr() net.Addr {
return nil
}
func NewUAPIListener(name string) (net.Listener, error) {
path := fmt.Sprintf(PipeNameFmt, name)
return winio.ListenPipe(path, &winio.PipeConfig{
InputBufferSize: 2048,
OutputBufferSize: 2048,
})
}

2
version.go Normal file
View File

@@ -0,0 +1,2 @@
package main
const WireGuardGoVersion = "0.0.20181001"

View File

@@ -1,7 +1,7 @@
/* SPDX-License-Identifier: MIT /* SPDX-License-Identifier: MIT
* *
* Copyright (C) 2016 Andreas Auernhammer. All Rights Reserved. * Copyright (C) 2016 Andreas Auernhammer. All Rights Reserved.
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/ */
package xchacha20poly1305 package xchacha20poly1305

View File

@@ -1,6 +1,6 @@
/* SPDX-License-Identifier: GPL-2.0 /* SPDX-License-Identifier: GPL-2.0
* *
* Copyright (C) 2017-2018 Jason A. Donenfeld <Jason@zx2c4.com>. All Rights Reserved. * Copyright (C) 2017-2018 WireGuard LLC. All Rights Reserved.
*/ */
package xchacha20poly1305 package xchacha20poly1305