64 Commits

Author SHA1 Message Date
Jason A. Donenfeld
ae88e2a2cd version: bump snapshot 2020-03-20 12:00:53 -06:00
Jason A. Donenfeld
4739708ca4 noise: unify zero checking of ecdh 2020-03-17 23:07:14 -06:00
Tobias Klauser
b33219c2cf global: use RTMGRP_* consts from x/sys/unix
Update the golang.org/x/sys/unix dependency and use the newly introduced
RTMGRP_* consts instead of using the corresponding RTNLGRP_* const to
create a mask.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
2020-03-17 23:07:11 -06:00
Jason A. Donenfeld
9cbcff10dd send: account for zero mtu
Don't divide by zero.
2020-02-14 18:53:55 +01:00
Jason A. Donenfeld
6ed56ff2df device: fix private key removal logic 2020-02-04 22:02:53 +01:00
Jason A. Donenfeld
cb4bb63030 uapi: allow unsetting device private key with /dev/null 2020-02-04 22:02:53 +01:00
Jason A. Donenfeld
05b03c6750 version: bump snapshot 2020-01-21 16:27:19 +01:00
Jason A. Donenfeld
caebdfe9d0 tun: darwin: ignore ENOMEM errors
Coauthored-by: Andrej Mihajlov <and@mullvad.net>
2020-01-15 13:39:37 -05:00
Jason A. Donenfeld
4fa2ea6a2d tun: windows: serialize write calls 2020-01-07 11:40:45 -05:00
Jason A. Donenfeld
89dd065e53 README: update repo urls 2019-12-30 11:53:39 +01:00
Jason A. Donenfeld
ddfad453cf device: SendmsgN mutates the input sockaddr
So we take a new granular lock to prevent concurrent writes from
racing.

WARNING: DATA RACE
Write at 0x00c0011f2740 by goroutine 27:
  golang.org/x/sys/unix.(*SockaddrInet4).sockaddr()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:384
+0x114
  golang.org/x/sys/unix.SendmsgN()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:1304
+0x288
  golang.zx2c4.com/wireguard/device.send4()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:485
+0x11f
  golang.zx2c4.com/wireguard/device.(*nativeBind).Send()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:268
+0x1d6
  golang.zx2c4.com/wireguard/device.(*Peer).SendBuffer()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/peer.go:151
+0x285
  golang.zx2c4.com/wireguard/device.(*Peer).SendHandshakeInitiation()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:163
+0x692
  golang.zx2c4.com/wireguard/device.(*Device).RoutineReadFromTUN()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:318
+0x4b8

Previous write at 0x00c0011f2740 by goroutine 386:
  golang.org/x/sys/unix.(*SockaddrInet4).sockaddr()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:384
+0x114
  golang.org/x/sys/unix.SendmsgN()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:1304
+0x288
  golang.zx2c4.com/wireguard/device.send4()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:485
+0x11f
  golang.zx2c4.com/wireguard/device.(*nativeBind).Send()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:268
+0x1d6
  golang.zx2c4.com/wireguard/device.(*Peer).SendBuffer()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/peer.go:151
+0x285
  golang.zx2c4.com/wireguard/device.(*Peer).SendHandshakeInitiation()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:163
+0x692
  golang.zx2c4.com/wireguard/device.expiredRetransmitHandshake()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/timers.go:110
+0x40c
  golang.zx2c4.com/wireguard/device.(*Peer).NewTimer.func1()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/timers.go:42
+0xd8

Goroutine 27 (running) created at:
  golang.zx2c4.com/wireguard/device.NewDevice()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/device.go:322
+0x5e8
  main.main()
      /go/src/x/main.go:102 +0x58e

Goroutine 386 (finished) created at:
  time.goFunc()
      /usr/local/go/src/time/sleep.go:168 +0x51

Reported-by: Ben Burkert <ben@benburkert.com>
2019-11-28 11:11:13 +01:00
Jason A. Donenfeld
2b242f9393 wintun: manage ring memory manually
It's large and Go's garbage collector doesn't deal with it especially
well.
2019-11-22 13:13:55 +01:00
Jason A. Donenfeld
4cdf805b29 constants: recalculate rekey max based on a one minute flood
Discussed-with: Mathias Hall-Andersen <mathias@hall-andersen.dk>
2019-10-30 14:29:32 +01:00
Jonathan Tooker
f7d0edd2ec global: fix a few typos courtesy of codespell
Signed-off-by: Jonathan Tooker <jonathan.tooker@netprotect.com>
2019-10-22 11:51:25 +02:00
Jason A. Donenfeld
ffffbbcc8a device: allow blackholing sockets 2019-10-21 13:29:57 +02:00
Jason A. Donenfeld
47b02c618b device: remove dead error reporting code 2019-10-21 11:46:54 +02:00
Jason A. Donenfeld
fd23c66fcd namespaceapi: remove tasteless comment 2019-10-21 09:02:29 +02:00
Jason A. Donenfeld
ae492d1b35 device: recheck counters while holding write lock 2019-10-17 15:43:06 +02:00
Jason A. Donenfeld
95fbfccf60 wintun: normalize variable names for their types 2019-10-17 15:30:56 +02:00
Avery Pennarun
c85e4a410f wintun: quickly ignore non-Wintun devices
Some devices take ~2 seconds to enumerate on Windows if we try to get
their instance name.  The hardware id property, on the other hand,
is available right away.

Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
[zx2c4: inlined this to where it makes sense, reused setupapi const]
2019-10-17 15:19:20 +02:00
Avery Pennarun
1b6c8ddbe8 tun: match windows CreateTUN signature to the Linux variant
Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
[zx2c4: fix default value]
2019-10-17 15:19:20 +02:00
Avery Pennarun
0abb6b668c rwcancel: handle EINTR and EAGAIN in unixSelect()
On my Chromebook (Linux 4.19.44 in a VM) and on an AWS EC2
machine, select() was sometimes returning EINTR. This is
harmless and just means you should try again. So let's try
again.

This eliminates a problem where the tunnel fails to come up
correctly and the program needs to be restarted.

Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
2019-10-17 15:19:17 +02:00
David Crawshaw
540d01e54a device: test packets between two fake devices
Signed-off-by: David Crawshaw <crawshaw@tailscale.io>
2019-10-16 11:38:28 +02:00
Jason A. Donenfeld
f2ea85e9f9 version: bump snapshot 2019-10-12 22:34:10 +02:00
Jason A. Donenfeld
222f0f8000 Makefile: remove v prefix 2019-10-08 16:48:18 +02:00
Jason A. Donenfeld
1f146a5e7a wintun: expose version 2019-10-08 09:58:58 +02:00
Jason A. Donenfeld
f2501aa6c8 uapi: allow preventing creation of new peers when updating
This enables race-free updates for wg-dynamic and similar tools.

Suggested-by: Thomas Gschwantner <tharre3@gmail.com>
2019-10-04 11:41:02 +02:00
Jason A. Donenfeld
cb8d01f58a mod: bump versions 2019-10-04 11:41:02 +02:00
Jason A. Donenfeld
01f8ef4e84 winpipe: use x/sys/windows instead of syscall 2019-09-16 23:39:16 -06:00
Jason A. Donenfeld
70f6c42556 wintun: use correct length for security attributes 2019-09-16 19:38:33 -06:00
Jason A. Donenfeld
bb0b2514c0 tun: windows: unify error message format 2019-09-08 13:52:44 -05:00
Jason A. Donenfeld
7c97fdb1e3 version: bump snapshot 2019-09-08 10:56:55 -05:00
Jason A. Donenfeld
84b5a4d83d main: simplify warnings 2019-09-08 10:56:00 -05:00
Jason A. Donenfeld
4cd06c0925 tun: openbsd: check for interface already being up
In some cases, we operate on an already-up interface, or the user brings
up the interface before we start monitoring. For those situations, we
should first check if the interface is already up.

This still technically races between the initial check and the start of
the route loop, but fixing that is a bit ugly and probably not worth it
at the moment.

Reported-by: Theo Buehler <tb@theobuehler.org>
2019-09-07 00:13:23 -05:00
Jason A. Donenfeld
d12eb91f9a namespaceapi: AddSIDToBoundaryDescriptor modifies the handle 2019-09-05 21:48:21 -06:00
Jason A. Donenfeld
73d3bd9cd5 wintun: take mutex first always
This prevents an ABA deadlock with setupapi's internal locks.
2019-09-01 21:32:28 -06:00
Jason A. Donenfeld
f3dba4c194 wintun: consider abandoned mutexes as released 2019-09-01 21:25:47 -06:00
Jason A. Donenfeld
7937840f96 ipc: windows: use protected prefix 2019-08-31 07:48:42 -06:00
Jason A. Donenfeld
e4b957183c winpipe: enforce ownership of client connection 2019-08-30 13:21:47 -06:00
Jason A. Donenfeld
950ca2ba8c wintun: put mutex into private namespace 2019-08-30 11:03:21 -06:00
Jason A. Donenfeld
df2bf34373 namespaceapi: fix mistake 2019-08-30 09:59:36 -06:00
Simon Rozman
a12b765784 namespaceapi: initial version
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-30 15:34:17 +02:00
Jason A. Donenfeld
14df9c3e75 wintun: take mutex so that deletion uses the right name 2019-08-30 15:34:17 +02:00
Jason A. Donenfeld
353f0956bc wintun: move ring constants into module 2019-08-29 13:22:17 -06:00
Jason A. Donenfeld
fa7763c268 wintun: delete all interfaces is not used anymore 2019-08-29 12:22:15 -06:00
Jason A. Donenfeld
d94bae8348 wintun: Wintun->Interface 2019-08-29 12:20:40 -06:00
Jason A. Donenfeld
7689d09336 wintun: keep reference to pool in wintun object 2019-08-29 12:13:16 -06:00
Simon Rozman
69c26dc258 wintun: introduce adapter pools
This makes wintun package reusable for non-WireGuard applications.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-29 18:00:44 +02:00
Jason A. Donenfeld
e862131d3c wintun: simplify rename logic 2019-08-28 19:31:20 -06:00
Jason A. Donenfeld
da28a3e9f3 wintun: give better errors when ndis interface listing fails 2019-08-28 08:39:26 -06:00
Jason A. Donenfeld
3bf3322b2c wintun: also check for numbered suffix and friendly name 2019-08-28 08:08:07 -06:00
Simon Rozman
7305b4ce93 wintun: upgrade deleting all interfaces and make it reusable
DeleteAllInterfaces() didn't check if SPDRP_DEVICEDESC == "WireGuard
Tunnel". It deleted _all_ Wintun adapters, not just WireGuard's.

Furthermore, the DeleteAllInterfaces() was upgraded into a new function
called DeleteMatchingInterfaces() for selectively deletion. This will
be used by WireGuard to clean stale Wintun adapters.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-28 11:39:01 +02:00
Jason A. Donenfeld
26fb615b11 wintun: cleanup earlier 2019-08-27 11:59:15 -06:00
Jason A. Donenfeld
7fbb24afaa wintun: rename duplicate adapters instead of ourselves 2019-08-27 11:59:15 -06:00
Jason A. Donenfeld
d9008ac35c wintun: match suffix numbers 2019-08-26 14:46:43 -06:00
Jason A. Donenfeld
f8198c0428 device: getsockname on linux to determine port
It turns out Go isn't passing the pointer properly so we wound up with a
zero port every time.
2019-08-25 12:45:13 -06:00
Jason A. Donenfeld
0c540ad60e wintun: make description consistent across fields 2019-08-24 12:29:17 +02:00
Jason A. Donenfeld
3cedc22d7b wintun: try multiple names until one isn't a duplicate 2019-08-22 08:52:59 +02:00
Jason A. Donenfeld
68fea631d8 wintun: use nci.dll directly instead of buggy netshell 2019-08-21 09:16:12 +02:00
Jason A. Donenfeld
ef23100a4f wintun: set friendly a bit better
This is still wrong, but NETSETUPPKEY_Driver_FriendlyName seems a bit
tricky to use.
2019-08-20 16:06:55 +02:00
Jason A. Donenfeld
eb786cd7c1 wintun: also set friendly name after setting interface name 2019-08-19 10:12:50 +02:00
Jason A. Donenfeld
333de75370 wintun: defer requires unique variable 2019-08-19 10:12:50 +02:00
Jason A. Donenfeld
d20459dc69 wintun: set adapter description name 2019-08-19 10:12:50 +02:00
Jason A. Donenfeld
01786286c1 tun: windows: don't spin unless we really need it 2019-08-19 10:12:50 +02:00
47 changed files with 1625 additions and 775 deletions

View File

@@ -1,30 +1,16 @@
PREFIX ?= /usr PREFIX ?= /usr
DESTDIR ?= DESTDIR ?=
BINDIR ?= $(PREFIX)/bin BINDIR ?= $(PREFIX)/bin
export GOPATH ?= $(CURDIR)/.gopath
export GO111MODULE := on export GO111MODULE := on
all: generate-version-and-build all: generate-version-and-build
ifeq ($(shell go env GOOS)|$(wildcard .git),linux|)
$(error Do not build this for Linux. Instead use the Linux kernel module. See wireguard.com/install/ for more info.)
else ifeq ($(shell go env GOOS),linux)
ireallywantobuildon_linux.go:
@printf "WARNING: This software is meant for use on non-Linux\nsystems. For Linux, please use the kernel module\ninstead. See wireguard.com/install/ for more info.\n\n" >&2
@printf 'package main\nconst UseTheKernelModuleInstead = 0xdeadbabe\n' > "$@"
clean-ireallywantobuildon_linux.go:
@rm -f ireallywantobuildon_linux.go
.PHONY: clean-ireallywantobuildon_linux.go
clean: clean-ireallywantobuildon_linux.go
wireguard-go: ireallywantobuildon_linux.go
endif
MAKEFLAGS += --no-print-directory MAKEFLAGS += --no-print-directory
generate-version-and-build: generate-version-and-build:
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \ @export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
tag="$$(git describe --dirty 2>/dev/null)" && \ tag="$$(git describe --dirty 2>/dev/null)" && \
ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$$tag")" && \ ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$${tag#v}")" && \
[ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \ [ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \
echo "$$ver" > device/version.go && \ echo "$$ver" > device/version.go && \
git update-index --assume-unchanged device/version.go || true git update-index --assume-unchanged device/version.go || true

View File

@@ -18,7 +18,7 @@ To run wireguard-go without forking to the background, pass `-f` or `--foregroun
$ wireguard-go -f wg0 $ wireguard-go -f wg0
``` ```
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/WireGuard/about/src/tools/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands. When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
To run with more logging you may set the environment variable `LOG_LEVEL=debug`. To run with more logging you may set the environment variable `LOG_LEVEL=debug`.

View File

@@ -18,7 +18,7 @@ const (
sockoptIPV6_UNICAST_IF = 31 sockoptIPV6_UNICAST_IF = 31
) )
func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error { func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
bytes := make([]byte, 4) bytes := make([]byte, 4)
binary.BigEndian.PutUint32(bytes, interfaceIndex) binary.BigEndian.PutUint32(bytes, interfaceIndex)
@@ -41,10 +41,11 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
if err != nil { if err != nil {
return err return err
} }
device.net.bind.(*nativeBind).blackhole4 = blackhole
return nil return nil
} }
func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error { func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn() sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
if err != nil { if err != nil {
return err return err
@@ -58,5 +59,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
if err != nil { if err != nil {
return err return err
} }
device.net.bind.(*nativeBind).blackhole6 = blackhole
return nil return nil
} }

View File

@@ -21,8 +21,10 @@ import (
*/ */
type nativeBind struct { type nativeBind struct {
ipv4 *net.UDPConn ipv4 *net.UDPConn
ipv6 *net.UDPConn ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
} }
type NativeEndpoint net.UDPAddr type NativeEndpoint net.UDPAddr
@@ -159,11 +161,17 @@ func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
if bind.ipv4 == nil { if bind.ipv4 == nil {
return syscall.EAFNOSUPPORT return syscall.EAFNOSUPPORT
} }
if bind.blackhole4 {
return nil
}
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
} else { } else {
if bind.ipv6 == nil { if bind.ipv6 == nil {
return syscall.EAFNOSUPPORT return syscall.EAFNOSUPPORT
} }
if bind.blackhole6 {
return nil
}
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
} }
return err return err

View File

@@ -7,7 +7,7 @@
* This implements userspace semantics of "sticky sockets", modeled after * This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port * WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code: * of the sticky-sockets.c example code:
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c * https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c
* *
* Currently there is no way to achieve this within the net package: * Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930 * See e.g. https://github.com/golang/go/issues/17930
@@ -43,6 +43,7 @@ type IPv6Source struct {
} }
type NativeEndpoint struct { type NativeEndpoint struct {
sync.Mutex
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
src [unsafe.Sizeof(IPv6Source{})]byte src [unsafe.Sizeof(IPv6Source{})]byte
isV6 bool isV6 bool
@@ -117,7 +118,7 @@ func createNetlinkRouteSocket() (int, error) {
} }
saddr := &unix.SockaddrNetlink{ saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK, Family: unix.AF_NETLINK,
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), Groups: unix.RTMGRP_IPV4_ROUTE,
} }
err = unix.Bind(sock, saddr) err = unix.Bind(sock, saddr)
if err != nil { if err != nil {
@@ -145,7 +146,7 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
go bind.routineRouteListener(device) go bind.routineRouteListener(device)
// attempt ipv6 bind, update port if succesful // attempt ipv6 bind, update port if successful
bind.sock6, newPort, err = create6(port) bind.sock6, newPort, err = create6(port)
if err != nil { if err != nil {
@@ -157,7 +158,7 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
port = newPort port = newPort
} }
// attempt ipv4 bind, update port if succesful // attempt ipv4 bind, update port if successful
bind.sock4, newPort, err = create4(port) bind.sock4, newPort, err = create4(port)
if err != nil { if err != nil {
@@ -391,6 +392,11 @@ func create4(port uint16) (int, uint16, error) {
return FD_ERR, 0, err return FD_ERR, 0, err
} }
sa, err := unix.Getsockname(fd)
if err == nil {
addr.Port = sa.(*unix.SockaddrInet4).Port
}
return fd, uint16(addr.Port), err return fd, uint16(addr.Port), err
} }
@@ -450,6 +456,11 @@ func create6(port uint16) (int, uint16, error) {
return FD_ERR, 0, err return FD_ERR, 0, err
} }
sa, err := unix.Getsockname(fd)
if err == nil {
addr.Port = sa.(*unix.SockaddrInet6).Port
}
return fd, uint16(addr.Port), err return fd, uint16(addr.Port), err
} }
@@ -472,7 +483,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
}, },
} }
end.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.Unlock()
if err == nil { if err == nil {
return nil return nil
@@ -483,7 +496,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
if err == unix.EINVAL { if err == unix.EINVAL {
end.ClearSrc() end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{} cmsg.pktinfo = unix.Inet4Pktinfo{}
end.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.Unlock()
} }
return err return err
@@ -512,7 +527,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
cmsg.pktinfo.Ifindex = 0 cmsg.pktinfo.Ifindex = 0
} }
end.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.Unlock()
if err == nil { if err == nil {
return nil return nil
@@ -523,7 +540,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
if err == unix.EINVAL { if err == unix.EINVAL {
end.ClearSrc() end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{} cmsg.pktinfo = unix.Inet6Pktinfo{}
end.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.Unlock()
} }
return err return err
@@ -531,7 +550,7 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header // construct message header
var cmsg struct { var cmsg struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr
@@ -563,7 +582,7 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header // construct message header
var cmsg struct { var cmsg struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr

View File

@@ -12,7 +12,7 @@ import (
/* Specification constants */ /* Specification constants */
const ( const (
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 RekeyAfterMessages = (1 << 60)
RejectAfterMessages = (1 << 64) - (1 << 4) - 1 RejectAfterMessages = (1 << 64) - (1 << 4) - 1
RekeyAfterTime = time.Second * 120 RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90 RekeyAttemptTime = time.Second * 90

View File

@@ -236,23 +236,11 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// do static-static DH pre-computations // do static-static DH pre-computations
rmKey := device.staticIdentity.privateKey.IsZero()
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for key, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
handshake := &peer.handshake handshake := &peer.handshake
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
if rmKey { expiredPeers = append(expiredPeers, peer)
handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
}
if isZero(handshake.precomputedStaticStatic[:]) {
unsafeRemovePeer(device, peer, key)
} else {
expiredPeers = append(expiredPeers, peer)
}
} }
for _, peer := range lockedPeers { for _, peer := range lockedPeers {

View File

@@ -5,54 +5,212 @@
package device package device
/* Create two device instances and simulate full WireGuard interaction
* without network dependencies
*/
import ( import (
"bufio"
"bytes" "bytes"
"encoding/binary"
"io"
"net"
"os"
"strings"
"testing" "testing"
"time"
"golang.zx2c4.com/wireguard/tun"
) )
func TestDevice(t *testing.T) { func TestTwoDevicePing(t *testing.T) {
// TODO(crawshaw): pick unused ports on localhost
// prepare tun devices for generating traffic cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58
listen_port=53511
tun1 := newDummyTUN("tun1") replace_peers=true
tun2 := newDummyTUN("tun2") public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725
protocol_version=1
_ = tun1 replace_allowed_ips=true
_ = tun2 allowed_ip=1.0.0.2/32
endpoint=127.0.0.1:53512`
// prepare endpoints tun1 := NewChannelTUN()
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
end1, err := CreateDummyEndpoint() dev1.Up()
if err != nil { defer dev1.Close()
t.Error("failed to create endpoint:", err.Error()) if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil {
}
end2, err := CreateDummyEndpoint()
if err != nil {
t.Error("failed to create endpoint:", err.Error())
}
_ = end1
_ = end2
// create binds
}
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tun := newDummyTUN("dummy")
logger := NewLogger(LogLevelError, "") cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768
device := NewDevice(tun, logger) listen_port=53512
device.SetPrivateKey(sk) replace_peers=true
return device public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427
protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.1/32
endpoint=127.0.0.1:53511`
tun2 := NewChannelTUN()
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
dev2.Up()
defer dev2.Close()
if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil {
t.Fatal(err)
}
t.Run("ping 1.0.0.1", func(t *testing.T) {
msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
tun2.Outbound <- msg2to1
select {
case msgRecv := <-tun1.Inbound:
if !bytes.Equal(msg2to1, msgRecv) {
t.Error("ping did not transit correctly")
}
case <-time.After(300 * time.Millisecond):
t.Error("ping did not transit")
}
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
tun1.Outbound <- msg1to2
select {
case msgRecv := <-tun2.Inbound:
if !bytes.Equal(msg1to2, msgRecv) {
t.Error("return ping did not transit correctly")
}
case <-time.After(300 * time.Millisecond):
t.Error("return ping did not transit")
}
})
}
func ping(dst, src net.IP) []byte {
localPort := uint16(1337)
seq := uint16(0)
payload := make([]byte, 4)
binary.BigEndian.PutUint16(payload[0:], localPort)
binary.BigEndian.PutUint16(payload[2:], seq)
return genICMPv4(payload, dst, src)
}
// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
func checksum(buf []byte, initial uint16) uint16 {
v := uint32(initial)
for i := 0; i < len(buf)-1; i += 2 {
v += uint32(binary.BigEndian.Uint16(buf[i:]))
}
if len(buf)%2 == 1 {
v += uint32(buf[len(buf)-1]) << 8
}
for v > 0xffff {
v = (v >> 16) + (v & 0xffff)
}
return ^uint16(v)
}
func genICMPv4(payload []byte, dst, src net.IP) []byte {
const (
icmpv4ProtocolNumber = 1
icmpv4Echo = 8
icmpv4ChecksumOffset = 2
icmpv4Size = 8
ipv4Size = 20
ipv4TotalLenOffset = 2
ipv4ChecksumOffset = 10
ttl = 65
)
hdr := make([]byte, ipv4Size+icmpv4Size)
ip := hdr[0:ipv4Size]
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
// https://tools.ietf.org/html/rfc792
icmpv4[0] = icmpv4Echo // type
icmpv4[1] = 0 // code
chksum := ^checksum(icmpv4, checksum(payload, 0))
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
// https://tools.ietf.org/html/rfc760 section 3.1
length := uint16(len(hdr) + len(payload))
ip[0] = (4 << 4) | (ipv4Size / 4)
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl
ip[9] = icmpv4ProtocolNumber
copy(ip[12:], src.To4())
copy(ip[16:], dst.To4())
chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
var v []byte
v = append(v, hdr...)
v = append(v, payload...)
return []byte(v)
}
// TODO(crawshaw): find a reusable home for this. package devicetest?
type ChannelTUN struct {
Inbound chan []byte // incoming packets, closed on TUN close
Outbound chan []byte // outbound packets, blocks forever on TUN close
closed chan struct{}
events chan tun.Event
tun chTun
}
func NewChannelTUN() *ChannelTUN {
c := &ChannelTUN{
Inbound: make(chan []byte),
Outbound: make(chan []byte),
closed: make(chan struct{}),
events: make(chan tun.Event, 1),
}
c.tun.c = c
c.events <- tun.EventUp
return c
}
func (c *ChannelTUN) TUN() tun.Device {
return &c.tun
}
type chTun struct {
c *ChannelTUN
}
func (t *chTun) File() *os.File { return nil }
func (t *chTun) Read(data []byte, offset int) (int, error) {
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case msg := <-t.c.Outbound:
return copy(data[offset:], msg), nil
}
}
// Write is called by the wireguard device to deliver a packet for routing.
func (t *chTun) Write(data []byte, offset int) (int, error) {
if offset == -1 {
close(t.c.closed)
close(t.c.events)
return 0, io.EOF
}
msg := make([]byte, len(data)-offset)
copy(msg, data[offset:])
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case t.c.Inbound <- msg:
return len(data) - offset, nil
}
}
func (t *chTun) Flush() error { return nil }
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
func (t *chTun) Events() chan tun.Event { return t.c.events }
func (t *chTun) Close() error {
t.Write(nil, -1)
return nil
} }
func assertNil(t *testing.T, err error) { func assertNil(t *testing.T, err error) {
@@ -66,3 +224,15 @@ func assertEqual(t *testing.T, a, b []byte) {
t.Fatal(a, "!=", b) t.Fatal(a, "!=", b)
} }
} }
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
tun := newDummyTUN("dummy")
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
}

View File

@@ -39,13 +39,13 @@ const (
) )
const ( const (
MessageInitiationSize = 148 // size of handshake initation message MessageInitiationSize = 148 // size of handshake initiation message
MessageResponseSize = 92 // size of response message MessageResponseSize = 92 // size of response message
MessageCookieReplySize = 64 // size of cookie reply message MessageCookieReplySize = 64 // size of cookie reply message
MessageTransportHeaderSize = 16 // size of data preceeding content in transport message MessageTransportHeaderSize = 16 // size of data preceding content in transport message
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
MessageKeepaliveSize = MessageTransportSize // size of keepalive MessageKeepaliveSize = MessageTransportSize // size of keepalive
MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
) )
const ( const (
@@ -154,6 +154,7 @@ func init() {
} }
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
var errZeroECDHResult = errors.New("ECDH returned all zeros")
device.staticIdentity.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock() defer device.staticIdentity.RUnlock()
@@ -162,12 +163,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errors.New("static shared secret is zero")
}
// create ephemeral key // create ephemeral key
var err error var err error
handshake.hash = InitialHash handshake.hash = InitialHash
handshake.chainKey = InitialChainKey handshake.chainKey = InitialChainKey
@@ -176,56 +172,53 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
return nil, err return nil, err
} }
// assign index
device.indexTable.Delete(handshake.localIndex)
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
handshake.mixHash(handshake.remoteStatic[:]) handshake.mixHash(handshake.remoteStatic[:])
msg := MessageInitiation{ msg := MessageInitiation{
Type: MessageInitiationType, Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(), Ephemeral: handshake.localEphemeral.publicKey(),
Sender: handshake.localIndex,
} }
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
// encrypt static key // encrypt static key
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
func() { if isZero(ss[:]) {
var key [chacha20poly1305.KeySize]byte return nil, errZeroECDHResult
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) }
KDF2( var key [chacha20poly1305.KeySize]byte
&handshake.chainKey, KDF2(
&key, &handshake.chainKey,
handshake.chainKey[:], &key,
ss[:], handshake.chainKey[:],
) ss[:],
aead, _ := chacha20poly1305.New(key[:]) )
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) aead, _ := chacha20poly1305.New(key[:])
}() aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
handshake.mixHash(msg.Static[:]) handshake.mixHash(msg.Static[:])
// encrypt timestamp // encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errZeroECDHResult
}
KDF2(
&handshake.chainKey,
&key,
handshake.chainKey[:],
handshake.precomputedStaticStatic[:],
)
timestamp := tai64n.Now() timestamp := tai64n.Now()
func() { aead, _ = chacha20poly1305.New(key[:])
var key [chacha20poly1305.KeySize]byte aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
KDF2(
&handshake.chainKey, // assign index
&key, device.indexTable.Delete(handshake.localIndex)
handshake.chainKey[:], msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
handshake.precomputedStaticStatic[:], if err != nil {
) return nil, err
aead, _ := chacha20poly1305.New(key[:]) }
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) handshake.localIndex = msg.Sender
}()
handshake.mixHash(msg.Timestamp[:]) handshake.mixHash(msg.Timestamp[:])
handshake.state = HandshakeInitiationCreated handshake.state = HandshakeInitiationCreated
@@ -250,16 +243,16 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key // decrypt static key
var err error var err error
var peerPK NoisePublicKey var peerPK NoisePublicKey
func() { var key [chacha20poly1305.KeySize]byte
var key [chacha20poly1305.KeySize]byte ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) if isZero(ss[:]) {
KDF2(&chainKey, &key, chainKey[:], ss[:]) return nil
aead, _ := chacha20poly1305.New(key[:]) }
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) KDF2(&chainKey, &key, chainKey[:], ss[:])
}() aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
if err != nil { if err != nil {
return nil return nil
} }
@@ -273,23 +266,24 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
} }
handshake := &peer.handshake handshake := &peer.handshake
if isZero(handshake.precomputedStaticStatic[:]) {
return nil
}
// verify identity // verify identity
var timestamp tai64n.Timestamp var timestamp tai64n.Timestamp
var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock() handshake.mutex.RLock()
if isZero(handshake.precomputedStaticStatic[:]) {
handshake.mutex.RUnlock()
return nil
}
KDF2( KDF2(
&chainKey, &chainKey,
&key, &key,
chainKey[:], chainKey[:],
handshake.precomputedStaticStatic[:], handshake.precomputedStaticStatic[:],
) )
aead, _ := chacha20poly1305.New(key[:]) aead, _ = chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
if err != nil { if err != nil {
handshake.mutex.RUnlock() handshake.mutex.RUnlock()
@@ -315,8 +309,13 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
handshake.chainKey = chainKey handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral handshake.remoteEphemeral = msg.Ephemeral
handshake.lastTimestamp = timestamp if timestamp.After(handshake.lastTimestamp) {
handshake.lastInitiationConsumption = time.Now() handshake.lastTimestamp = timestamp
}
now := time.Now()
if now.After(handshake.lastInitiationConsumption) {
handshake.lastInitiationConsumption = now
}
handshake.state = HandshakeInitiationConsumed handshake.state = HandshakeInitiationConsumed
handshake.mutex.Unlock() handshake.mutex.Unlock()

View File

@@ -52,6 +52,15 @@ func (key *NoisePrivateKey) FromHex(src string) (err error) {
return return
} }
func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
err = loadExactHex(key[:], src)
if key.IsZero() {
return
}
key.clamp()
return
}
func (key NoisePrivateKey) ToHex() string { func (key NoisePrivateKey) ToHex() string {
return hex.EncodeToString(key[:]) return hex.EncodeToString(key[:])
} }

View File

@@ -108,7 +108,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
ssIsZero := isZero(handshake.precomputedStaticStatic[:])
handshake.remoteStatic = pk handshake.remoteStatic = pk
handshake.mutex.Unlock() handshake.mutex.Unlock()
@@ -116,13 +115,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.endpoint = nil peer.endpoint = nil
// conditionally add // add
if !ssIsZero { device.peers.keyMap[pk] = peer
device.peers.keyMap[pk] = peer
} else {
return nil, nil
}
// start peer // start peer

View File

@@ -220,10 +220,7 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
writer := bytes.NewBuffer(buff[:0]) writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, reply) binary.Write(writer, binary.LittleEndian, reply)
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
if err != nil { return nil
device.log.Error.Println("Failed to send cookie reply:", err)
}
return err
} }
func (peer *Peer) keepKeyFreshSending() { func (peer *Peer) keepKeyFreshSending() {
@@ -518,10 +515,18 @@ func (device *Device) RoutineEncryption() {
// pad content to multiple of 16 // pad content to multiple of 16
mtu := int(atomic.LoadInt32(&device.tun.mtu)) mtu := int(atomic.LoadInt32(&device.tun.mtu))
lastUnit := len(elem.packet) % mtu var paddedSize int
paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1) if mtu == 0 {
if paddedSize > mtu { paddedSize = (len(elem.packet) + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
paddedSize = mtu } else {
lastUnit := len(elem.packet)
if lastUnit > mtu {
lastUnit %= mtu
}
paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
if paddedSize > mtu {
paddedSize = mtu
}
} }
for i := len(elem.packet); i < paddedSize; i++ { for i := len(elem.packet); i < paddedSize; i++ {
elem.packet = append(elem.packet, 0) elem.packet = append(elem.packet, 0)

View File

@@ -113,6 +113,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
var peer *Peer var peer *Peer
dummy := false dummy := false
createdNewPeer := false
deviceConfig := true deviceConfig := true
for scanner.Scan() { for scanner.Scan() {
@@ -137,7 +138,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
switch key { switch key {
case "private_key": case "private_key":
var sk NoisePrivateKey var sk NoisePrivateKey
err := sk.FromHex(value) err := sk.FromMaybeZeroHex(value)
if err != nil { if err != nil {
logError.Println("Failed to set private_key:", err) logError.Println("Failed to set private_key:", err)
return &IPCError{ipc.IpcErrorInvalid} return &IPCError{ipc.IpcErrorInvalid}
@@ -237,7 +238,8 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
peer = device.LookupPeer(publicKey) peer = device.LookupPeer(publicKey)
} }
if peer == nil { createdNewPeer = peer == nil
if createdNewPeer {
peer, err = device.NewPeer(publicKey) peer, err = device.NewPeer(publicKey)
if err != nil { if err != nil {
logError.Println("Failed to create new peer:", err) logError.Println("Failed to create new peer:", err)
@@ -251,6 +253,20 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
} }
} }
case "update_only":
// allow disabling of creation
if value != "true" {
logError.Println("Failed to set update only, invalid value:", value)
return &IPCError{ipc.IpcErrorInvalid}
}
if createdNewPeer && !dummy {
device.RemovePeer(peer.handshake.remoteStatic)
peer = &Peer{}
dummy = true
}
case "remove": case "remove":
// remove currently selected peer from device // remove currently selected peer from device

View File

@@ -1,3 +1,3 @@
package device package device
const WireGuardGoVersion = "0.0.20190805" const WireGuardGoVersion = "0.0.20200320"

View File

@@ -1,15 +0,0 @@
// +build !android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package main
const DoNotUseThisProgramOnLinux = UseTheKernelModuleInstead
// --------------------------------------------------------
// Do not use this on Linux. Instead use the kernel module.
// See wireguard.com/install for more information.
// --------------------------------------------------------

7
go.mod
View File

@@ -3,7 +3,8 @@ module golang.zx2c4.com/wireguard
go 1.12 go 1.12
require ( require (
golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56 golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980 golang.org/x/net v0.0.0-20191003171128-d98b1b443823
golang.org/x/sys v0.0.0-20190618155005-516e3c20635f golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527
golang.org/x/text v0.3.2
) )

15
go.sum
View File

@@ -1,11 +1,14 @@
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56 h1:ZpKuNIejY8P0ExLOVyKhb0WsgG8UdvHXe6TWjY7eL6k= golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4=
golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980 h1:dfGZHvZk057jK2MCeWus/TowKpJ8y4AmooUzdBSR9GU= golang.org/x/net v0.0.0-20191003171128-d98b1b443823 h1:Ypyv6BNJh07T1pUSrehkLemqPKXhus2MkfktJ91kRh4=
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190618155005-516e3c20635f h1:dHNZYIYdq2QuU6w73vZ/DzesPbVlZVYZTtTZmrnsbQ8= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=
golang.org/x/sys v0.0.0-20190618155005-516e3c20635f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -8,6 +8,8 @@ package ipc
import ( import (
"net" "net"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/winpipe" "golang.zx2c4.com/wireguard/ipc/winpipe"
) )
@@ -47,14 +49,22 @@ func (l *UAPIListener) Addr() net.Addr {
return l.listener.Addr() return l.listener.Addr()
} }
/* SDDL_DEVOBJ_SYS_ALL from the WDK */ var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR
var UAPISecurityDescriptor = "O:SYD:P(A;;GA;;;SY)"
func init() {
var err error
/* SDDL_DEVOBJ_SYS_ALL from the WDK */
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
if err != nil {
panic(err)
}
}
func UAPIListen(name string) (net.Listener, error) { func UAPIListen(name string) (net.Listener, error) {
config := winpipe.PipeConfig{ config := winpipe.PipeConfig{
SecurityDescriptor: UAPISecurityDescriptor, SecurityDescriptor: UAPISecurityDescriptor,
} }
listener, err := winpipe.ListenPipe("\\\\.\\pipe\\WireGuard\\"+name, &config) listener, err := winpipe.ListenPipe(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -13,15 +13,16 @@ import (
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"syscall"
"time" "time"
"golang.org/x/sys/windows"
) )
//sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx //sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx
//sys createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) = CreateIoCompletionPort //sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort
//sys getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus //sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
//sys setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes //sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
//sys wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult //sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
type atomicBool int32 type atomicBool int32
@@ -55,7 +56,7 @@ func (e *timeoutError) Temporary() bool { return true }
type timeoutChan chan struct{} type timeoutChan chan struct{}
var ioInitOnce sync.Once var ioInitOnce sync.Once
var ioCompletionPort syscall.Handle var ioCompletionPort windows.Handle
// ioResult contains the result of an asynchronous IO operation // ioResult contains the result of an asynchronous IO operation
type ioResult struct { type ioResult struct {
@@ -65,12 +66,12 @@ type ioResult struct {
// ioOperation represents an outstanding asynchronous Win32 IO // ioOperation represents an outstanding asynchronous Win32 IO
type ioOperation struct { type ioOperation struct {
o syscall.Overlapped o windows.Overlapped
ch chan ioResult ch chan ioResult
} }
func initIo() { func initIo() {
h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff) h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -81,7 +82,7 @@ func initIo() {
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. // win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
// It takes ownership of this handle and will close it if it is garbage collected. // It takes ownership of this handle and will close it if it is garbage collected.
type win32File struct { type win32File struct {
handle syscall.Handle handle windows.Handle
wg sync.WaitGroup wg sync.WaitGroup
wgLock sync.RWMutex wgLock sync.RWMutex
closing atomicBool closing atomicBool
@@ -99,7 +100,7 @@ type deadlineHandler struct {
} }
// makeWin32File makes a new win32File from an existing file handle // makeWin32File makes a new win32File from an existing file handle
func makeWin32File(h syscall.Handle) (*win32File, error) { func makeWin32File(h windows.Handle) (*win32File, error) {
f := &win32File{handle: h} f := &win32File{handle: h}
ioInitOnce.Do(initIo) ioInitOnce.Do(initIo)
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff) _, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
@@ -115,7 +116,7 @@ func makeWin32File(h syscall.Handle) (*win32File, error) {
return f, nil return f, nil
} }
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) { func MakeOpenFile(h windows.Handle) (io.ReadWriteCloser, error) {
return makeWin32File(h) return makeWin32File(h)
} }
@@ -129,7 +130,7 @@ func (f *win32File) closeHandle() {
cancelIoEx(f.handle, nil) cancelIoEx(f.handle, nil)
f.wg.Wait() f.wg.Wait()
// at this point, no new IO can start // at this point, no new IO can start
syscall.Close(f.handle) windows.Close(f.handle)
f.handle = 0 f.handle = 0
} else { } else {
f.wgLock.Unlock() f.wgLock.Unlock()
@@ -158,12 +159,12 @@ func (f *win32File) prepareIo() (*ioOperation, error) {
} }
// ioCompletionProcessor processes completed async IOs forever // ioCompletionProcessor processes completed async IOs forever
func ioCompletionProcessor(h syscall.Handle) { func ioCompletionProcessor(h windows.Handle) {
for { for {
var bytes uint32 var bytes uint32
var key uintptr var key uintptr
var op *ioOperation var op *ioOperation
err := getQueuedCompletionStatus(h, &bytes, &key, &op, syscall.INFINITE) err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE)
if op == nil { if op == nil {
panic(err) panic(err)
} }
@@ -174,7 +175,7 @@ func ioCompletionProcessor(h syscall.Handle) {
// asyncIo processes the return value from ReadFile or WriteFile, blocking until // asyncIo processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed. // the operation has actually completed.
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
if err != syscall.ERROR_IO_PENDING { if err != windows.ERROR_IO_PENDING {
return int(bytes), err return int(bytes), err
} }
@@ -193,7 +194,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
select { select {
case r = <-c.ch: case r = <-c.ch:
err = r.err err = r.err
if err == syscall.ERROR_OPERATION_ABORTED { if err == windows.ERROR_OPERATION_ABORTED {
if f.closing.isSet() { if f.closing.isSet() {
err = ErrFileClosed err = ErrFileClosed
} }
@@ -206,7 +207,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
cancelIoEx(f.handle, &c.o) cancelIoEx(f.handle, &c.o)
r = <-c.ch r = <-c.ch
err = r.err err = r.err
if err == syscall.ERROR_OPERATION_ABORTED { if err == windows.ERROR_OPERATION_ABORTED {
err = ErrTimeout err = ErrTimeout
} }
} }
@@ -231,14 +232,14 @@ func (f *win32File) Read(b []byte) (int, error) {
} }
var bytes uint32 var bytes uint32
err = syscall.ReadFile(f.handle, b, &bytes, &c.o) err = windows.ReadFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.readDeadline, bytes, err) n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
runtime.KeepAlive(b) runtime.KeepAlive(b)
// Handle EOF conditions. // Handle EOF conditions.
if err == nil && n == 0 && len(b) != 0 { if err == nil && n == 0 && len(b) != 0 {
return 0, io.EOF return 0, io.EOF
} else if err == syscall.ERROR_BROKEN_PIPE { } else if err == windows.ERROR_BROKEN_PIPE {
return 0, io.EOF return 0, io.EOF
} else { } else {
return n, err return n, err
@@ -258,7 +259,7 @@ func (f *win32File) Write(b []byte) (int, error) {
} }
var bytes uint32 var bytes uint32
err = syscall.WriteFile(f.handle, b, &bytes, &c.o) err = windows.WriteFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err) n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
runtime.KeepAlive(b) runtime.KeepAlive(b)
return n, err return n, err
@@ -273,7 +274,7 @@ func (f *win32File) SetWriteDeadline(deadline time.Time) error {
} }
func (f *win32File) Flush() error { func (f *win32File) Flush() error {
return syscall.FlushFileBuffers(f.handle) return windows.FlushFileBuffers(f.handle)
} }
func (f *win32File) Fd() uintptr { func (f *win32File) Fd() uintptr {

View File

@@ -6,4 +6,4 @@
package winpipe package winpipe
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go pipe.go sd.go file.go //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go pipe.go file.go

View File

@@ -16,18 +16,19 @@ import (
"net" "net"
"os" "os"
"runtime" "runtime"
"syscall"
"time" "time"
"unsafe" "unsafe"
"golang.org/x/sys/windows"
) )
//sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe //sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateNamedPipeW //sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW
//sys createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateFileW //sys createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW
//sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo //sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
//sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW //sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc //sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile //sys ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb //sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U //sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl //sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
@@ -41,7 +42,7 @@ type objectAttributes struct {
RootDirectory uintptr RootDirectory uintptr
ObjectName *unicodeString ObjectName *unicodeString
Attributes uintptr Attributes uintptr
SecurityDescriptor *securityDescriptor SecurityDescriptor *windows.SECURITY_DESCRIPTOR
SecurityQoS uintptr SecurityQoS uintptr
} }
@@ -51,16 +52,6 @@ type unicodeString struct {
Buffer uintptr Buffer uintptr
} }
type securityDescriptor struct {
Revision byte
Sbz1 byte
Control uint16
Owner uintptr
Group uintptr
Sacl uintptr
Dacl uintptr
}
type ntstatus int32 type ntstatus int32
func (status ntstatus) Err() error { func (status ntstatus) Err() error {
@@ -71,11 +62,6 @@ func (status ntstatus) Err() error {
} }
const ( const (
cERROR_PIPE_BUSY = syscall.Errno(231)
cERROR_NO_DATA = syscall.Errno(232)
cERROR_PIPE_CONNECTED = syscall.Errno(535)
cERROR_SEM_TIMEOUT = syscall.Errno(121)
cSECURITY_SQOS_PRESENT = 0x100000 cSECURITY_SQOS_PRESENT = 0x100000
cSECURITY_ANONYMOUS = 0 cSECURITY_ANONYMOUS = 0
@@ -88,8 +74,6 @@ const (
cFILE_PIPE_MESSAGE_TYPE = 1 cFILE_PIPE_MESSAGE_TYPE = 1
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2 cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
cSE_DACL_PRESENT = 4
) )
var ( var (
@@ -170,7 +154,7 @@ func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
// zero-byte message, ensure that all future Read() calls // zero-byte message, ensure that all future Read() calls
// also return EOF. // also return EOF.
f.readEOF = true f.readEOF = true
} else if err == syscall.ERROR_MORE_DATA { } else if err == windows.ERROR_MORE_DATA {
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode // ERROR_MORE_DATA indicates that the pipe's read mode is message mode
// and the message still has more bytes. Treat this as a success, since // and the message still has more bytes. Treat this as a success, since
// this package presents all named pipes as byte streams. // this package presents all named pipes as byte streams.
@@ -188,17 +172,17 @@ func (s pipeAddress) String() string {
} }
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout. // tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) { func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return syscall.Handle(0), ctx.Err() return windows.Handle(0), ctx.Err()
default: default:
h, err := createFile(*path, syscall.GENERIC_READ|syscall.GENERIC_WRITE, 0, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0) h, err := createFile(*path, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
if err == nil { if err == nil {
return h, nil return h, nil
} }
if err != cERROR_PIPE_BUSY { if err != windows.ERROR_PIPE_BUSY {
return h, &os.PathError{Err: err, Op: "open", Path: *path} return h, &os.PathError{Err: err, Op: "open", Path: *path}
} }
// Wait 10 msec and try again. This is a rather simplistic // Wait 10 msec and try again. This is a rather simplistic
@@ -211,7 +195,7 @@ func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
// DialPipe connects to a named pipe by path, timing out if the connection // DialPipe connects to a named pipe by path, timing out if the connection
// takes longer than the specified duration. If timeout is nil, then we use // takes longer than the specified duration. If timeout is nil, then we use
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.) // a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
func DialPipe(path string, timeout *time.Duration) (net.Conn, error) { func DialPipe(path string, timeout *time.Duration, expectedOwner *windows.SID) (net.Conn, error) {
var absTimeout time.Time var absTimeout time.Time
if timeout != nil { if timeout != nil {
absTimeout = time.Now().Add(*timeout) absTimeout = time.Now().Add(*timeout)
@@ -219,7 +203,7 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
absTimeout = time.Now().Add(time.Second * 2) absTimeout = time.Now().Add(time.Second * 2)
} }
ctx, _ := context.WithDeadline(context.Background(), absTimeout) ctx, _ := context.WithDeadline(context.Background(), absTimeout)
conn, err := DialPipeContext(ctx, path) conn, err := DialPipeContext(ctx, path, expectedOwner)
if err == context.DeadlineExceeded { if err == context.DeadlineExceeded {
return nil, ErrTimeout return nil, ErrTimeout
} }
@@ -228,23 +212,41 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx` // DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
// cancellation or timeout. // cancellation or timeout.
func DialPipeContext(ctx context.Context, path string) (net.Conn, error) { func DialPipeContext(ctx context.Context, path string, expectedOwner *windows.SID) (net.Conn, error) {
var err error var err error
var h syscall.Handle var h windows.Handle
h, err = tryDialPipe(ctx, &path) h, err = tryDialPipe(ctx, &path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if expectedOwner != nil {
sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
if err != nil {
windows.Close(h)
return nil, err
}
realOwner, _, err := sd.Owner()
if err != nil {
windows.Close(h)
return nil, err
}
if !realOwner.Equals(expectedOwner) {
windows.Close(h)
return nil, windows.ERROR_ACCESS_DENIED
}
}
var flags uint32 var flags uint32
err = getNamedPipeInfo(h, &flags, nil, nil, nil) err = getNamedPipeInfo(h, &flags, nil, nil, nil)
if err != nil { if err != nil {
windows.Close(h)
return nil, err return nil, err
} }
f, err := makeWin32File(h) f, err := makeWin32File(h)
if err != nil { if err != nil {
syscall.Close(h) windows.Close(h)
return nil, err return nil, err
} }
@@ -264,7 +266,7 @@ type acceptResponse struct {
} }
type win32PipeListener struct { type win32PipeListener struct {
firstHandle syscall.Handle firstHandle windows.Handle
path string path string
config PipeConfig config PipeConfig
acceptCh chan (chan acceptResponse) acceptCh chan (chan acceptResponse)
@@ -272,8 +274,8 @@ type win32PipeListener struct {
doneCh chan int doneCh chan int
} }
func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (syscall.Handle, error) { func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *PipeConfig, first bool) (windows.Handle, error) {
path16, err := syscall.UTF16FromString(path) path16, err := windows.UTF16FromString(path)
if err != nil { if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err} return 0, &os.PathError{Op: "open", Path: path, Err: err}
} }
@@ -285,31 +287,32 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil { if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err} return 0, &os.PathError{Op: "open", Path: path, Err: err}
} }
defer localFree(ntPath.Buffer) defer windows.LocalFree(windows.Handle(ntPath.Buffer))
oa.ObjectName = &ntPath oa.ObjectName = &ntPath
// The security descriptor is only needed for the first pipe. // The security descriptor is only needed for the first pipe.
if first { if first {
if sd != nil { if sd != nil {
len := uint32(len(sd)) oa.SecurityDescriptor = sd
sdb := localAlloc(0, len)
defer localFree(sdb)
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
} else { } else {
// Construct the default named pipe security descriptor. // Construct the default named pipe security descriptor.
var dacl uintptr var dacl uintptr
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil { if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
return 0, fmt.Errorf("getting default named pipe ACL: %s", err) return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
} }
defer localFree(dacl) defer windows.LocalFree(windows.Handle(dacl))
sd, err := windows.NewSecurityDescriptor()
sdb := &securityDescriptor{ if err != nil {
Revision: 1, return 0, fmt.Errorf("creating new security descriptor: %s", err)
Control: cSE_DACL_PRESENT,
Dacl: dacl,
} }
oa.SecurityDescriptor = sdb if err = sd.SetDACL((*windows.ACL)(unsafe.Pointer(dacl)), true, false); err != nil {
return 0, fmt.Errorf("assigning dacl: %s", err)
}
sd, err = sd.ToSelfRelative()
if err != nil {
return 0, fmt.Errorf("converting to self-relative: %s", err)
}
oa.SecurityDescriptor = sd
} }
} }
@@ -319,22 +322,22 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
} }
disposition := uint32(cFILE_OPEN) disposition := uint32(cFILE_OPEN)
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE) access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
if first { if first {
disposition = cFILE_CREATE disposition = cFILE_CREATE
// By not asking for read or write access, the named pipe file system // By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking // will put this pipe into an initially disconnected state, blocking
// client connections until the next call with first == false. // client connections until the next call with first == false.
access = syscall.SYNCHRONIZE access = windows.SYNCHRONIZE
} }
timeout := int64(-50 * 10000) // 50ms timeout := int64(-50 * 10000) // 50ms
var ( var (
h syscall.Handle h windows.Handle
iosb ioStatusBlock iosb ioStatusBlock
) )
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err() err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
if err != nil { if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err} return 0, &os.PathError{Op: "open", Path: path, Err: err}
} }
@@ -350,7 +353,7 @@ func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
} }
f, err := makeWin32File(h) f, err := makeWin32File(h)
if err != nil { if err != nil {
syscall.Close(h) windows.Close(h)
return nil, err return nil, err
} }
return f, nil return f, nil
@@ -401,7 +404,7 @@ func (l *win32PipeListener) listenerRoutine() {
p, err = l.makeConnectedServerPipe() p, err = l.makeConnectedServerPipe()
// If the connection was immediately closed by the client, try // If the connection was immediately closed by the client, try
// again. // again.
if err != cERROR_NO_DATA { if err != windows.ERROR_NO_DATA {
break break
} }
} }
@@ -409,7 +412,7 @@ func (l *win32PipeListener) listenerRoutine() {
closed = err == ErrPipeListenerClosed closed = err == ErrPipeListenerClosed
} }
} }
syscall.Close(l.firstHandle) windows.Close(l.firstHandle)
l.firstHandle = 0 l.firstHandle = 0
// Notify Close() and Accept() callers that the handle has been closed. // Notify Close() and Accept() callers that the handle has been closed.
close(l.doneCh) close(l.doneCh)
@@ -417,8 +420,8 @@ func (l *win32PipeListener) listenerRoutine() {
// PipeConfig contain configuration for the pipe listener. // PipeConfig contain configuration for the pipe listener.
type PipeConfig struct { type PipeConfig struct {
// SecurityDescriptor contains a Windows security descriptor in SDDL format. // SecurityDescriptor contains a Windows security descriptor.
SecurityDescriptor string SecurityDescriptor *windows.SECURITY_DESCRIPTOR
// MessageMode determines whether the pipe is in byte or message mode. In either // MessageMode determines whether the pipe is in byte or message mode. In either
// case the pipe is read in byte mode by default. The only practical difference in // case the pipe is read in byte mode by default. The only practical difference in
@@ -438,20 +441,10 @@ type PipeConfig struct {
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe. // ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
// The pipe must not already exist. // The pipe must not already exist.
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) { func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
var (
sd []byte
err error
)
if c == nil { if c == nil {
c = &PipeConfig{} c = &PipeConfig{}
} }
if c.SecurityDescriptor != "" { h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
sd, err = SddlToSecurityDescriptor(c.SecurityDescriptor)
if err != nil {
return nil, err
}
}
h, err := makeServerPipeHandle(path, sd, c, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -476,7 +469,7 @@ func connectPipe(p *win32File) error {
err = connectNamedPipe(p.handle, &c.o) err = connectNamedPipe(p.handle, &c.o)
_, err = p.asyncIo(c, nil, 0, err) _, err = p.asyncIo(c, nil, 0, err)
if err != nil && err != cERROR_PIPE_CONNECTED { if err != nil && err != windows.ERROR_PIPE_CONNECTED {
return err return err
} }
return nil return nil

View File

@@ -1,29 +0,0 @@
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package winpipe
import (
"unsafe"
)
//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
//sys localFree(mem uintptr) = LocalFree
//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
var sdBuffer uintptr
err := convertStringSecurityDescriptorToSecurityDescriptor(sddl, 1, &sdBuffer, nil)
if err != nil {
return nil, err
}
defer localFree(sdBuffer)
sd := make([]byte, getSecurityDescriptorLength(sdBuffer))
copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
return sd, nil
}

View File

@@ -39,30 +39,26 @@ func errnoErr(e syscall.Errno) error {
var ( var (
modkernel32 = windows.NewLazySystemDLL("kernel32.dll") modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
modntdll = windows.NewLazySystemDLL("ntdll.dll") modntdll = windows.NewLazySystemDLL("ntdll.dll")
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe") procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW") procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
procCreateFileW = modkernel32.NewProc("CreateFileW") procCreateFileW = modkernel32.NewProc("CreateFileW")
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo") procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW") procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
procLocalAlloc = modkernel32.NewProc("LocalAlloc") procLocalAlloc = modkernel32.NewProc("LocalAlloc")
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile") procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb") procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U") procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl") procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW") procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procLocalFree = modkernel32.NewProc("LocalFree") procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength") procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
procCancelIoEx = modkernel32.NewProc("CancelIoEx") procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort") procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
) )
func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) { func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0) r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 { if r1 == 0 {
if e1 != 0 { if e1 != 0 {
@@ -74,7 +70,7 @@ func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
return return
} }
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) { func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
var _p0 *uint16 var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name) _p0, err = syscall.UTF16PtrFromString(name)
if err != nil { if err != nil {
@@ -83,10 +79,10 @@ func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances ui
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa) return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
} }
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) { func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0) r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
handle = syscall.Handle(r0) handle = windows.Handle(r0)
if handle == syscall.InvalidHandle { if handle == windows.InvalidHandle {
if e1 != 0 { if e1 != 0 {
err = errnoErr(e1) err = errnoErr(e1)
} else { } else {
@@ -96,7 +92,7 @@ func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances
return return
} }
func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) { func createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
var _p0 *uint16 var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name) _p0, err = syscall.UTF16PtrFromString(name)
if err != nil { if err != nil {
@@ -105,10 +101,10 @@ func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAtt
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile) return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
} }
func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) { func _createFile(name *uint16, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0) r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
handle = syscall.Handle(r0) handle = windows.Handle(r0)
if handle == syscall.InvalidHandle { if handle == windows.InvalidHandle {
if e1 != 0 { if e1 != 0 {
err = errnoErr(e1) err = errnoErr(e1)
} else { } else {
@@ -118,7 +114,7 @@ func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityA
return return
} }
func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) { func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0) r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
if r1 == 0 { if r1 == 0 {
if e1 != 0 { if e1 != 0 {
@@ -130,7 +126,7 @@ func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSiz
return return
} }
func getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) { func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0) r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
if r1 == 0 { if r1 == 0 {
if e1 != 0 { if e1 != 0 {
@@ -148,7 +144,7 @@ func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
return return
} }
func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) { func ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0) r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
status = ntstatus(r0) status = ntstatus(r0)
return return
@@ -174,39 +170,7 @@ func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) {
return return
} }
func convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) { func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(str)
if err != nil {
return
}
return _convertStringSecurityDescriptorToSecurityDescriptor(_p0, revision, sd, size)
}
func _convertStringSecurityDescriptorToSecurityDescriptor(str *uint16, revision uint32, sd *uintptr, size *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procConvertStringSecurityDescriptorToSecurityDescriptorW.Addr(), 4, uintptr(unsafe.Pointer(str)), uintptr(revision), uintptr(unsafe.Pointer(sd)), uintptr(unsafe.Pointer(size)), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func localFree(mem uintptr) {
syscall.Syscall(procLocalFree.Addr(), 1, uintptr(mem), 0, 0)
return
}
func getSecurityDescriptorLength(sd uintptr) (len uint32) {
r0, _, _ := syscall.Syscall(procGetSecurityDescriptorLength.Addr(), 1, uintptr(sd), 0, 0)
len = uint32(r0)
return
}
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0) r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 { if r1 == 0 {
if e1 != 0 { if e1 != 0 {
@@ -218,9 +182,9 @@ func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
return return
} }
func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) { func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) {
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0) r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
newport = syscall.Handle(r0) newport = windows.Handle(r0)
if newport == 0 { if newport == 0 {
if e1 != 0 { if e1 != 0 {
err = errnoErr(e1) err = errnoErr(e1)
@@ -231,7 +195,7 @@ func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintpt
return return
} }
func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) { func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0) r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
if r1 == 0 { if r1 == 0 {
if e1 != 0 { if e1 != 0 {
@@ -243,7 +207,7 @@ func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr,
return return
} }
func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) { func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0) r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
if r1 == 0 { if r1 == 0 {
if e1 != 0 { if e1 != 0 {
@@ -255,7 +219,7 @@ func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err erro
return return
} }
func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) { func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
var _p0 uint32 var _p0 uint32
if wait { if wait {
_p0 = 1 _p0 = 1

16
main.go
View File

@@ -40,31 +40,19 @@ func warning() {
if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" { if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
return return
} }
shouldQuit := os.Getenv("WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD") != "1"
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING") fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G") fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G")
fmt.Fprintln(os.Stderr, "W which is probably unnecessary and foolish. This G") fmt.Fprintln(os.Stderr, "W which is probably unnecessary and misguided. This G")
fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G") fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G")
fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G") fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G")
fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G") fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G")
fmt.Fprintln(os.Stderr, "W implementation. For more information on G") fmt.Fprintln(os.Stderr, "W implementation. For more information on G")
fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G") fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G")
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G") fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
if shouldQuit {
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W If you still want to use this program, against G")
fmt.Fprintln(os.Stderr, "W the advice here, please first export this G")
fmt.Fprintln(os.Stderr, "W environment variable: G")
fmt.Fprintln(os.Stderr, "W WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD=1 G")
}
fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING") fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
if shouldQuit {
os.Exit(1)
}
} }
func main() { func main() {
@@ -75,8 +63,6 @@ func main() {
warning() warning()
// parse arguments
var foreground bool var foreground bool
var interfaceName string var interfaceName string
if len(os.Args) < 2 || len(os.Args) > 3 { if len(os.Args) < 2 || len(os.Args) > 3 {

View File

@@ -37,7 +37,7 @@ func main() {
logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion) logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
logger.Debug.Println("Debug log enabled") logger.Debug.Println("Debug log enabled")
tun, err := tun.CreateTUN(interfaceName) tun, err := tun.CreateTUN(interfaceName, 0)
if err == nil { if err == nil {
realInterfaceName, err2 := tun.Name() realInterfaceName, err2 := tun.Name()
if err2 == nil { if err2 == nil {

View File

@@ -36,7 +36,7 @@ func TestRatelimiter(t *testing.T) {
for i := 0; i < packetsBurstable; i++ { for i := 0; i < packetsBurstable; i++ {
Add(RatelimiterResult{ Add(RatelimiterResult{
allowed: true, allowed: true,
text: "inital burst", text: "initial burst",
}) })
} }

View File

@@ -60,7 +60,13 @@ func (rw *RWCancel) ReadyRead() bool {
fdset := fdSet{} fdset := fdSet{}
fdset.set(rw.fd) fdset.set(rw.fd)
fdset.set(closeFd) fdset.set(closeFd)
err := unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil) var err error
for {
err = unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
if err == nil || !RetryAfterError(err) {
break
}
}
if err != nil { if err != nil {
return false return false
} }
@@ -75,7 +81,13 @@ func (rw *RWCancel) ReadyWrite() bool {
fdset := fdSet{} fdset := fdSet{}
fdset.set(rw.fd) fdset.set(rw.fd)
fdset.set(closeFd) fdset.set(closeFd)
err := unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil) var err error
for {
err = unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
if err == nil || !RetryAfterError(err) {
break
}
}
if err != nil { if err != nil {
return false return false
} }

View File

@@ -11,6 +11,7 @@ import (
"net" "net"
"os" "os"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
@@ -42,6 +43,22 @@ type NativeTun struct {
var sockaddrCtlSize uintptr = 32 var sockaddrCtlSize uintptr = 32
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
for i := 0; i < 20; i++ {
iface, err = net.InterfaceByIndex(index)
if err != nil {
if opErr, ok := err.(*net.OpError); ok {
if syscallErr, ok := opErr.Err.(*os.SyscallError); ok && syscallErr.Err == syscall.ENOMEM {
time.Sleep(time.Duration(i) * time.Second / 3)
continue
}
}
}
return iface, err
}
return nil, err
}
func (tun *NativeTun) routineRouteListener(tunIfindex int) { func (tun *NativeTun) routineRouteListener(tunIfindex int) {
var ( var (
statusUp bool statusUp bool
@@ -74,7 +91,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
continue continue
} }
iface, err := net.InterfaceByIndex(ifindex) iface, err := retryInterfaceByIndex(ifindex)
if err != nil { if err != nil {
tun.errors <- err tun.errors <- err
return return

View File

@@ -35,7 +35,7 @@ type NativeTun struct {
name string // name of interface name string // name of interface
errors chan error // async error handling errors chan error // async error handling
events chan Event // device related events events chan Event // device related events
nopi bool // the device was pased IFF_NO_PI nopi bool // the device was passed IFF_NO_PI
netlinkSock int netlinkSock int
netlinkCancel *rwcancel.RWCancel netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex hackListenerClosed sync.Mutex
@@ -85,7 +85,7 @@ func createNetlinkSocket() (int, error) {
} }
saddr := &unix.SockaddrNetlink{ saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK, Family: unix.AF_NETLINK,
Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))), Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
} }
err = unix.Bind(sock, saddr) err = unix.Bind(sock, saddr)
if err != nil { if err != nil {

View File

@@ -42,34 +42,11 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
defer close(tun.events) defer close(tun.events)
data := make([]byte, os.Getpagesize()) check := func() bool {
for { iface, err := net.InterfaceByIndex(tunIfindex)
retry:
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
goto retry
}
tun.errors <- err
return
}
if n < 8 {
continue
}
if data[3 /* type */] != unix.RTM_IFINFO {
continue
}
ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */])))
if ifindex != tunIfindex {
continue
}
iface, err := net.InterfaceByIndex(ifindex)
if err != nil { if err != nil {
tun.errors <- err tun.errors <- err
return return true
} }
// Up / Down event // Up / Down event
@@ -87,6 +64,38 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
tun.events <- EventMTUUpdate tun.events <- EventMTUUpdate
} }
statusMTU = iface.MTU statusMTU = iface.MTU
return false
}
if check() {
return
}
data := make([]byte, os.Getpagesize())
for {
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
continue
}
tun.errors <- err
return
}
if n < 8 {
continue
}
if data[3 /* type */] != unix.RTM_IFINFO {
continue
}
ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */])))
if ifindex != tunIfindex {
continue
}
if check() {
return
}
} }
} }
@@ -140,7 +149,6 @@ func CreateTUN(name string, mtu int) (Device, error) {
} }
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{ tun := &NativeTun{
tunFile: file, tunFile: file,
events: make(chan Event, 10), events: make(chan Event, 10),

View File

@@ -9,6 +9,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe" "unsafe"
@@ -19,87 +20,71 @@ import (
) )
const ( const (
packetAlignment uint32 = 4 // Number of bytes packets are aligned to in rings rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
packetSizeMax = 0xffff // Maximum packet size spinloopRateThreshold = 800000000 / 8 // 800mbps
packetCapacity = 0x800000 // Ring capacity, 8MiB spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
packetTrailingSize = uint32(unsafe.Sizeof(packetHeader{})) + ((packetSizeMax + (packetAlignment - 1)) &^ (packetAlignment - 1)) - packetAlignment
ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14)
) )
type packetHeader struct { type rateJuggler struct {
size uint32 current uint64
} nextByteCount uint64
nextStartTime int64
type packet struct { changing int32
packetHeader
data [packetSizeMax]byte
}
type ring struct {
head uint32
tail uint32
alertable int32
data [packetCapacity + packetTrailingSize]byte
}
type ringDescriptor struct {
send, receive struct {
size uint32
ring *ring
tailMoved windows.Handle
}
} }
type NativeTun struct { type NativeTun struct {
wt *wintun.Wintun wt *wintun.Interface
handle windows.Handle handle windows.Handle
close bool close bool
rings ringDescriptor
events chan Event events chan Event
errors chan error errors chan error
forcedMTU int forcedMTU int
rate rateJuggler
rings *wintun.RingDescriptor
writeLock sync.Mutex
} }
func packetAlign(size uint32) uint32 { const WintunPool = wintun.Pool("WireGuard")
return (size + (packetAlignment - 1)) &^ (packetAlignment - 1)
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
//go:linkname nanotime runtime.nanotime
func nanotime() int64
//
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
//
func CreateTUN(ifname string, mtu int) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, nil, mtu)
} }
// //
// CreateTUN creates a Wintun adapter with the given name. Should a Wintun // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// adapter with the same name exist, it is reused. // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
// //
func CreateTUN(ifname string) (Device, error) { func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, nil)
}
//
// CreateTUNWithRequestedGUID creates a Wintun adapter with the given name and
// a requested GUID. Should a Wintun adapter with the same name exist, it is reused.
//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Device, error) {
var err error var err error
var wt *wintun.Wintun var wt *wintun.Interface
// Does an interface with this name already exist? // Does an interface with this name already exist?
wt, err = wintun.GetInterface(ifname) wt, err = WintunPool.GetInterface(ifname)
if err == nil { if err == nil {
// If so, we delete it, in case it has weird residual configuration. // If so, we delete it, in case it has weird residual configuration.
_, err = wt.DeleteInterface() _, err = wt.DeleteInterface()
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to delete already existing Wintun interface: %v", err) return nil, fmt.Errorf("Error deleting already existing interface: %v", err)
} }
} else if err == windows.ERROR_ALREADY_EXISTS {
return nil, fmt.Errorf("Foreign network interface with the same name exists")
} }
wt, _, err = wintun.CreateInterface("WireGuard Tunnel Adapter", requestedGUID) wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID)
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to create Wintun interface: %v", err) return nil, fmt.Errorf("Error creating interface: %v", err)
} }
err = wt.SetInterfaceName(ifname) forcedMTU := 1420
if err != nil { if mtu > 0 {
wt.DeleteInterface() forcedMTU = mtu
return nil, fmt.Errorf("Unable to set name of Wintun interface: %v", err)
} }
tun := &NativeTun{ tun := &NativeTun{
@@ -107,33 +92,16 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev
handle: windows.InvalidHandle, handle: windows.InvalidHandle,
events: make(chan Event, 10), events: make(chan Event, 10),
errors: make(chan error, 1), errors: make(chan error, 1),
forcedMTU: 1500, forcedMTU: forcedMTU,
} }
tun.rings.send.size = uint32(unsafe.Sizeof(ring{})) tun.rings, err = wintun.NewRingDescriptor()
tun.rings.send.ring = &ring{}
tun.rings.send.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil { if err != nil {
tun.Close() tun.Close()
return nil, fmt.Errorf("Error creating event: %v", err) return nil, fmt.Errorf("Error creating events: %v", err)
} }
tun.rings.receive.size = uint32(unsafe.Sizeof(ring{})) tun.handle, err = tun.wt.Register(tun.rings)
tun.rings.receive.ring = &ring{}
tun.rings.receive.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
tun.Close()
return nil, fmt.Errorf("Error creating event: %v", err)
}
tun.handle, err = tun.wt.AdapterHandle()
if err != nil {
tun.Close()
return nil, err
}
var bytesReturned uint32
err = windows.DeviceIoControl(tun.handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(&tun.rings)), uint32(unsafe.Sizeof(tun.rings)), nil, 0, &bytesReturned, nil)
if err != nil { if err != nil {
tun.Close() tun.Close()
return nil, fmt.Errorf("Error registering rings: %v", err) return nil, fmt.Errorf("Error registering rings: %v", err)
@@ -142,7 +110,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev
} }
func (tun *NativeTun) Name() (string, error) { func (tun *NativeTun) Name() (string, error) {
return tun.wt.InterfaceName() return tun.wt.Name()
} }
func (tun *NativeTun) File() *os.File { func (tun *NativeTun) File() *os.File {
@@ -155,18 +123,13 @@ func (tun *NativeTun) Events() chan Event {
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
tun.close = true tun.close = true
if tun.rings.send.tailMoved != 0 { if tun.rings.Send.TailMoved != 0 {
windows.SetEvent(tun.rings.send.tailMoved) // wake the reader if it's sleeping windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping
} }
if tun.handle != windows.InvalidHandle { if tun.handle != windows.InvalidHandle {
windows.CloseHandle(tun.handle) windows.CloseHandle(tun.handle)
} }
if tun.rings.send.tailMoved != 0 { tun.rings.Close()
windows.CloseHandle(tun.rings.send.tailMoved)
}
if tun.rings.send.tailMoved != 0 {
windows.CloseHandle(tun.rings.receive.tailMoved)
}
var err error var err error
if tun.wt != nil { if tun.wt != nil {
_, err = tun.wt.DeleteInterface() _, err = tun.wt.DeleteInterface()
@@ -184,9 +147,6 @@ func (tun *NativeTun) ForceMTU(mtu int) {
tun.forcedMTU = mtu tun.forcedMTU = mtu
} }
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
@@ -200,50 +160,52 @@ retry:
return 0, os.ErrClosed return 0, os.ErrClosed
} }
buffHead := atomic.LoadUint32(&tun.rings.send.ring.head) buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head)
if buffHead >= packetCapacity { if buffHead >= wintun.PacketCapacity {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
start := time.Now() start := nanotime()
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
var buffTail uint32 var buffTail uint32
for { for {
buffTail = atomic.LoadUint32(&tun.rings.send.ring.tail) buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail)
if buffHead != buffTail { if buffHead != buffTail {
break break
} }
if tun.close { if tun.close {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
if time.Since(start) >= time.Millisecond/80 /* ~1gbit/s */ { if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
windows.WaitForSingleObject(tun.rings.send.tailMoved, windows.INFINITE) windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE)
goto retry goto retry
} }
procyield(1) procyield(1)
} }
if buffTail >= packetCapacity { if buffTail >= wintun.PacketCapacity {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
buffContent := tun.rings.send.ring.wrap(buffTail - buffHead) buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead)
if buffContent < uint32(unsafe.Sizeof(packetHeader{})) { if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) {
return 0, errors.New("incomplete packet header in send ring") return 0, errors.New("incomplete packet header in send ring")
} }
packet := (*packet)(unsafe.Pointer(&tun.rings.send.ring.data[buffHead])) packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead]))
if packet.size > packetSizeMax { if packet.Size > wintun.PacketSizeMax {
return 0, errors.New("packet too big in send ring") return 0, errors.New("packet too big in send ring")
} }
alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packet.size) alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size)
if alignedPacketSize > buffContent { if alignedPacketSize > buffContent {
return 0, errors.New("incomplete packet in send ring") return 0, errors.New("incomplete packet in send ring")
} }
copy(buff[offset:], packet.data[:packet.size]) copy(buff[offset:], packet.Data[:packet.Size])
buffHead = tun.rings.send.ring.wrap(buffHead + alignedPacketSize) buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize)
atomic.StoreUint32(&tun.rings.send.ring.head, buffHead) atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead)
return int(packet.size), nil tun.rate.update(uint64(packet.Size))
return int(packet.Size), nil
} }
func (tun *NativeTun) Flush() error { func (tun *NativeTun) Flush() error {
@@ -256,39 +218,58 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
} }
packetSize := uint32(len(buff) - offset) packetSize := uint32(len(buff) - offset)
alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packetSize) tun.rate.update(uint64(packetSize))
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
buffHead := atomic.LoadUint32(&tun.rings.receive.ring.head) tun.writeLock.Lock()
if buffHead >= packetCapacity { defer tun.writeLock.Unlock()
buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
if buffHead >= wintun.PacketCapacity {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
buffTail := atomic.LoadUint32(&tun.rings.receive.ring.tail) buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail)
if buffTail >= packetCapacity { if buffTail >= wintun.PacketCapacity {
return 0, os.ErrClosed return 0, os.ErrClosed
} }
buffSpace := tun.rings.receive.ring.wrap(buffHead - buffTail - packetAlignment) buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment)
if alignedPacketSize > buffSpace { if alignedPacketSize > buffSpace {
return 0, nil // Dropping when ring is full. return 0, nil // Dropping when ring is full.
} }
packet := (*packet)(unsafe.Pointer(&tun.rings.receive.ring.data[buffTail])) packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail]))
packet.size = packetSize packet.Size = packetSize
copy(packet.data[:packetSize], buff[offset:]) copy(packet.Data[:packetSize], buff[offset:])
atomic.StoreUint32(&tun.rings.receive.ring.tail, tun.rings.receive.ring.wrap(buffTail+alignedPacketSize)) atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize))
if atomic.LoadInt32(&tun.rings.receive.ring.alertable) != 0 { if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 {
windows.SetEvent(tun.rings.receive.tailMoved) windows.SetEvent(tun.rings.Receive.TailMoved)
} }
return int(packetSize), nil return int(packetSize), nil
} }
// LUID returns Windows adapter instance ID. // LUID returns Windows interface instance ID.
func (tun *NativeTun) LUID() uint64 { func (tun *NativeTun) LUID() uint64 {
return tun.wt.LUID() return tun.wt.LUID()
} }
// wrap returns value modulo ring capacity // Version returns the version of the Wintun driver and NDIS system currently loaded.
func (rb *ring) wrap(value uint32) uint32 { func (tun *NativeTun) Version() (driverVersion string, ndisVersion string, err error) {
return value & (packetCapacity - 1) return tun.wt.Version()
}
func (rate *rateJuggler) update(packetLen uint64) {
now := nanotime()
total := atomic.AddUint64(&rate.nextByteCount, packetLen)
period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
if period >= rateMeasurementGranularity {
if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
return
}
atomic.StoreInt64(&rate.nextStartTime, now)
atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
atomic.StoreUint64(&rate.nextByteCount, 0)
atomic.StoreInt32(&rate.changing, 0)
}
} }

View File

@@ -0,0 +1,25 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package iphlpapi
import "golang.org/x/sys/windows"
//sys convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) = iphlpapi.ConvertInterfaceLuidToGuid
//sys convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) = iphlpapi.ConvertInterfaceAliasToLuid
func InterfaceGUIDFromAlias(alias string) (*windows.GUID, error) {
var luid uint64
var guid windows.GUID
err := convertInterfaceAliasToLUID(windows.StringToUTF16Ptr(alias), &luid)
if err != nil {
return nil, err
}
err = convertInterfaceLUIDToGUID(&luid, &guid)
if err != nil {
return nil, err
}
return &guid, nil
}

View File

@@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package iphlpapi
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go conversion_windows.go

View File

@@ -0,0 +1,60 @@
// Code generated by 'go generate'; DO NOT EDIT.
package iphlpapi
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return nil
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid")
procConvertInterfaceAliasToLuid = modiphlpapi.NewProc("ConvertInterfaceAliasToLuid")
)
func convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) {
r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) {
r0, _, _ := syscall.Syscall(procConvertInterfaceAliasToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceAlias)), uintptr(unsafe.Pointer(interfaceLUID)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}

View File

@@ -0,0 +1,98 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"encoding/hex"
"errors"
"fmt"
"sync"
"unsafe"
"golang.org/x/crypto/blake2s"
"golang.org/x/sys/windows"
"golang.org/x/text/unicode/norm"
"golang.zx2c4.com/wireguard/tun/wintun/namespaceapi"
)
var (
wintunObjectSecurityAttributes *windows.SecurityAttributes
hasInitializedNamespace bool
initializingNamespace sync.Mutex
)
func initializeNamespace() error {
initializingNamespace.Lock()
defer initializingNamespace.Unlock()
if hasInitializedNamespace {
return nil
}
sd, err := windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
if err != nil {
return fmt.Errorf("SddlToSecurityDescriptor failed: %v", err)
}
wintunObjectSecurityAttributes = &windows.SecurityAttributes{
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
SecurityDescriptor: sd,
}
sid, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
if err != nil {
return fmt.Errorf("CreateWellKnownSid(LOCAL_SYSTEM) failed: %v", err)
}
boundary, err := namespaceapi.CreateBoundaryDescriptor("Wintun")
if err != nil {
return fmt.Errorf("CreateBoundaryDescriptor failed: %v", err)
}
err = boundary.AddSid(sid)
if err != nil {
return fmt.Errorf("AddSIDToBoundaryDescriptor failed: %v", err)
}
for {
_, err = namespaceapi.CreatePrivateNamespace(wintunObjectSecurityAttributes, boundary, "Wintun")
if err == windows.ERROR_ALREADY_EXISTS {
_, err = namespaceapi.OpenPrivateNamespace(boundary, "Wintun")
if err == windows.ERROR_PATH_NOT_FOUND {
continue
}
}
if err != nil {
return fmt.Errorf("Create/OpenPrivateNamespace failed: %v", err)
}
break
}
hasInitializedNamespace = true
return nil
}
func (pool Pool) takeNameMutex() (windows.Handle, error) {
err := initializeNamespace()
if err != nil {
return 0, err
}
const mutexLabel = "WireGuard Adapter Name Mutex Stable Suffix v1 jason@zx2c4.com"
b2, _ := blake2s.New256(nil)
b2.Write([]byte(mutexLabel))
b2.Write(norm.NFC.Bytes([]byte(string(pool))))
mutexName := `Wintun\Wintun-Name-Mutex-` + hex.EncodeToString(b2.Sum(nil))
mutex, err := windows.CreateMutex(wintunObjectSecurityAttributes, false, windows.StringToUTF16Ptr(mutexName))
if err != nil {
err = fmt.Errorf("Error creating name mutex: %v", err)
return 0, err
}
event, err := windows.WaitForSingleObject(mutex, windows.INFINITE)
if err != nil {
windows.CloseHandle(mutex)
return 0, fmt.Errorf("Error waiting on name mutex: %v", err)
}
if event != windows.WAIT_OBJECT_0 && event != windows.WAIT_ABANDONED {
windows.CloseHandle(mutex)
return 0, errors.New("Error with event trigger of name mutex")
}
return mutex, nil
}

View File

@@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package namespaceapi
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go namespaceapi_windows.go

View File

@@ -0,0 +1,83 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package namespaceapi
import "golang.org/x/sys/windows"
//sys createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) = kernel32.CreateBoundaryDescriptorW
//sys deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) = kernel32.DeleteBoundaryDescriptor
//sys addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) = kernel32.AddSIDToBoundaryDescriptor
//sys createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.CreatePrivateNamespaceW
//sys openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.OpenPrivateNamespaceW
//sys closePrivateNamespace(handle windows.Handle, flags uint32) (err error) = kernel32.ClosePrivateNamespace
// BoundaryDescriptor represents a boundary that defines how the objects in the namespace are to be isolated.
type BoundaryDescriptor windows.Handle
// CreateBoundaryDescriptor creates a boundary descriptor.
func CreateBoundaryDescriptor(name string) (BoundaryDescriptor, error) {
name16, err := windows.UTF16PtrFromString(name)
if err != nil {
return 0, err
}
handle, err := createBoundaryDescriptor(name16, 0)
if err != nil {
return 0, err
}
return BoundaryDescriptor(handle), nil
}
// Delete deletes the specified boundary descriptor.
func (bd BoundaryDescriptor) Delete() {
deleteBoundaryDescriptor(windows.Handle(bd))
}
// AddSid adds a security identifier (SID) to the specified boundary descriptor.
func (bd *BoundaryDescriptor) AddSid(requiredSid *windows.SID) error {
return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid)
}
// PrivateNamespace represents a private namespace.
type PrivateNamespace windows.Handle
// CreatePrivateNamespace creates a private namespace.
func CreatePrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) {
aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix)
if err != nil {
return 0, err
}
handle, err := createPrivateNamespace(privateNamespaceAttributes, windows.Handle(boundaryDescriptor), aliasPrefix16)
if err != nil {
return 0, err
}
return PrivateNamespace(handle), nil
}
// OpenPrivateNamespace opens a private namespace.
func OpenPrivateNamespace(boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) {
aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix)
if err != nil {
return 0, err
}
handle, err := openPrivateNamespace(windows.Handle(boundaryDescriptor), aliasPrefix16)
if err != nil {
return 0, err
}
return PrivateNamespace(handle), nil
}
// ClosePrivateNamespaceFlags describes flags that are used by PrivateNamespace's Close() method.
type ClosePrivateNamespaceFlags uint32
const (
// PrivateNamespaceFlagDestroy makes the close to destroy the namespace.
PrivateNamespaceFlagDestroy = ClosePrivateNamespaceFlags(0x1)
)
// Close closes an open namespace handle.
func (pns PrivateNamespace) Close(flags ClosePrivateNamespaceFlags) error {
return closePrivateNamespace(windows.Handle(pns), uint32(flags))
}

View File

@@ -0,0 +1,116 @@
// Code generated by 'go generate'; DO NOT EDIT.
package namespaceapi
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return nil
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procCreateBoundaryDescriptorW = modkernel32.NewProc("CreateBoundaryDescriptorW")
procDeleteBoundaryDescriptor = modkernel32.NewProc("DeleteBoundaryDescriptor")
procAddSIDToBoundaryDescriptor = modkernel32.NewProc("AddSIDToBoundaryDescriptor")
procCreatePrivateNamespaceW = modkernel32.NewProc("CreatePrivateNamespaceW")
procOpenPrivateNamespaceW = modkernel32.NewProc("OpenPrivateNamespaceW")
procClosePrivateNamespace = modkernel32.NewProc("ClosePrivateNamespace")
)
func createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall(procCreateBoundaryDescriptorW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(flags), 0)
handle = windows.Handle(r0)
if handle == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) {
syscall.Syscall(procDeleteBoundaryDescriptor.Addr(), 1, uintptr(boundaryDescriptor), 0, 0)
return
}
func addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) {
r1, _, e1 := syscall.Syscall(procAddSIDToBoundaryDescriptor.Addr(), 2, uintptr(unsafe.Pointer(boundaryDescriptor)), uintptr(unsafe.Pointer(requiredSid)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall(procCreatePrivateNamespaceW.Addr(), 3, uintptr(unsafe.Pointer(privateNamespaceAttributes)), uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)))
handle = windows.Handle(r0)
if handle == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall(procOpenPrivateNamespaceW.Addr(), 2, uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)), 0)
handle = windows.Handle(r0)
if handle == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func closePrivateNamespace(handle windows.Handle, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procClosePrivateNamespace.Addr(), 2, uintptr(handle), uintptr(flags), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}

View File

@@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package nci
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go nci_windows.go

View File

@@ -0,0 +1,28 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package nci
import "golang.org/x/sys/windows"
//sys nciSetConnectionName(guid *windows.GUID, newName *uint16) (ret error) = nci.NciSetConnectionName
//sys nciGetConnectionName(guid *windows.GUID, destName *uint16, inDestNameBytes uint32, outDestNameBytes *uint32) (ret error) = nci.NciGetConnectionName
func SetConnectionName(guid *windows.GUID, newName string) error {
newName16, err := windows.UTF16PtrFromString(newName)
if err != nil {
return err
}
return nciSetConnectionName(guid, newName16)
}
func ConnectionName(guid *windows.GUID) (string, error) {
var name [0x400]uint16
err := nciGetConnectionName(guid, &name[0], uint32(len(name)*2), nil)
if err != nil {
return "", err
}
return windows.UTF16ToString(name[:]), nil
}

View File

@@ -0,0 +1,60 @@
// Code generated by 'go generate'; DO NOT EDIT.
package nci
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return nil
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modnci = windows.NewLazySystemDLL("nci.dll")
procNciSetConnectionName = modnci.NewProc("NciSetConnectionName")
procNciGetConnectionName = modnci.NewProc("NciGetConnectionName")
)
func nciSetConnectionName(guid *windows.GUID, newName *uint16) (ret error) {
r0, _, _ := syscall.Syscall(procNciSetConnectionName.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(newName)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func nciGetConnectionName(guid *windows.GUID, destName *uint16, inDestNameBytes uint32, outDestNameBytes *uint32) (ret error) {
r0, _, _ := syscall.Syscall6(procNciGetConnectionName.Addr(), 4, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(destName)), uintptr(inDestNameBytes), uintptr(unsafe.Pointer(outDestNameBytes)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}

View File

@@ -1,32 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package netshell
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var (
modnetshell = windows.NewLazySystemDLL("netshell.dll")
procHrRenameConnection = modnetshell.NewProc("HrRenameConnection")
)
func HrRenameConnection(guid *windows.GUID, newName *uint16) (err error) {
err = procHrRenameConnection.Find()
if err != nil {
// Missing from servercore, so we can't presume it's always there.
return
}
ret, _, _ := syscall.Syscall(procHrRenameConnection.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(newName)), 0)
if ret != 0 {
err = syscall.Errno(ret)
}
return
}

View File

@@ -5,4 +5,4 @@
package registry package registry
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zregistry_windows.go registry_windows.go //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zregistry_windows.go registry_windows.go

117
tun/wintun/ring_windows.go Normal file
View File

@@ -0,0 +1,117 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"runtime"
"unsafe"
"golang.org/x/sys/windows"
)
const (
PacketAlignment = 4 // Number of bytes packets are aligned to in rings
PacketSizeMax = 0xffff // Maximum packet size
PacketCapacity = 0x800000 // Ring capacity, 8MiB
PacketTrailingSize = uint32(unsafe.Sizeof(PacketHeader{})) + ((PacketSizeMax + (PacketAlignment - 1)) &^ (PacketAlignment - 1)) - PacketAlignment
ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14)
)
type PacketHeader struct {
Size uint32
}
type Packet struct {
PacketHeader
Data [PacketSizeMax]byte
}
type Ring struct {
Head uint32
Tail uint32
Alertable int32
Data [PacketCapacity + PacketTrailingSize]byte
}
type RingDescriptor struct {
Send, Receive struct {
Size uint32
Ring *Ring
TailMoved windows.Handle
}
}
// Wrap returns value modulo ring capacity
func (rb *Ring) Wrap(value uint32) uint32 {
return value & (PacketCapacity - 1)
}
// Aligns a packet size to PacketAlignment
func PacketAlign(size uint32) uint32 {
return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
}
func NewRingDescriptor() (descriptor *RingDescriptor, err error) {
descriptor = new(RingDescriptor)
allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
if err != nil {
return
}
defer func() {
if err != nil {
descriptor.free()
descriptor = nil
}
}()
descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion))
descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return
}
descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{})))
descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
windows.CloseHandle(descriptor.Send.TailMoved)
return
}
runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() })
return
}
func (descriptor *RingDescriptor) free() {
if descriptor.Send.Ring != nil {
windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE)
descriptor.Send.Ring = nil
descriptor.Receive.Ring = nil
}
}
func (descriptor *RingDescriptor) Close() {
if descriptor.Send.TailMoved != 0 {
windows.CloseHandle(descriptor.Send.TailMoved)
descriptor.Send.TailMoved = 0
}
if descriptor.Send.TailMoved != 0 {
windows.CloseHandle(descriptor.Receive.TailMoved)
descriptor.Receive.TailMoved = 0
}
}
func (wintun *Interface) Register(descriptor *RingDescriptor) (windows.Handle, error) {
handle, err := wintun.handle()
if err != nil {
return 0, err
}
var bytesReturned uint32
err = windows.DeviceIoControl(handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(descriptor)), uint32(unsafe.Sizeof(*descriptor)), nil, 0, &bytesReturned, nil)
if err != nil {
return 0, err
}
return handle, nil
}

View File

@@ -5,4 +5,4 @@
package setupapi package setupapi
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsetupapi_windows.go setupapi_windows.go //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsetupapi_windows.go setupapi_windows.go

View File

@@ -57,7 +57,7 @@ type DevInfoData struct {
_ uintptr _ uintptr
} }
// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass). // DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supersedes the functionality of SetupDiGetDeviceInfoListClass).
type DevInfoListDetailData struct { type DevInfoListDetailData struct {
size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const. size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const.
ClassGUID windows.GUID ClassGUID windows.GUID

View File

@@ -15,17 +15,20 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
"golang.zx2c4.com/wireguard/tun/wintun/netshell" "golang.zx2c4.com/wireguard/tun/wintun/iphlpapi"
"golang.zx2c4.com/wireguard/tun/wintun/nci"
registryEx "golang.zx2c4.com/wireguard/tun/wintun/registry" registryEx "golang.zx2c4.com/wireguard/tun/wintun/registry"
"golang.zx2c4.com/wireguard/tun/wintun/setupapi" "golang.zx2c4.com/wireguard/tun/wintun/setupapi"
) )
// Wintun is a handle of a Wintun adapter. type Pool string
type Wintun struct {
type Interface struct {
cfgInstanceID windows.GUID cfgInstanceID windows.GUID
devInstanceID string devInstanceID string
luidIndex uint32 luidIndex uint32
ifType uint32 ifType uint32
pool Pool
} }
var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}} var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}}
@@ -37,9 +40,9 @@ const (
) )
// makeWintun creates a Wintun interface handle and populates it from the device's registry key. // makeWintun creates a Wintun interface handle and populates it from the device's registry key.
func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) (*Wintun, error) { func makeWintun(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData, pool Pool) (*Interface, error) {
// Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key. // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE) key, err := devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE)
if err != nil { if err != nil {
return nil, fmt.Errorf("Device-specific registry key open failed: %v", err) return nil, fmt.Errorf("Device-specific registry key open failed: %v", err)
} }
@@ -69,30 +72,48 @@ func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfo
return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err) return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err)
} }
instanceID, err := deviceInfoSet.DeviceInstanceID(deviceInfoData) instanceID, err := devInfo.DeviceInstanceID(devInfoData)
if err != nil { if err != nil {
return nil, fmt.Errorf("DeviceInstanceID failed: %v", err) return nil, fmt.Errorf("DeviceInstanceID failed: %v", err)
} }
return &Wintun{ return &Interface{
cfgInstanceID: ifid, cfgInstanceID: ifid,
devInstanceID: instanceID, devInstanceID: instanceID,
luidIndex: uint32(luidIdx), luidIndex: uint32(luidIdx),
ifType: uint32(ifType), ifType: uint32(ifType),
pool: pool,
}, nil }, nil
} }
func removeNumberedSuffix(ifname string) string {
removed := strings.TrimRight(ifname, "0123456789")
if removed != ifname && len(removed) > 1 && removed[len(removed)-1] == ' ' {
return removed[:len(removed)-1]
}
return ifname
}
// GetInterface finds a Wintun interface by its name. This function returns // GetInterface finds a Wintun interface by its name. This function returns
// the interface if found, or windows.ERROR_OBJECT_NOT_FOUND otherwise. If // the interface if found, or windows.ERROR_OBJECT_NOT_FOUND otherwise. If
// the interface is found but not a Wintun-class, this function returns // the interface is found but not a Wintun-class or a member of the pool,
// windows.ERROR_ALREADY_EXISTS. // this function returns windows.ERROR_ALREADY_EXISTS.
func GetInterface(ifname string) (*Wintun, error) { func (pool Pool) GetInterface(ifname string) (*Interface, error) {
mutex, err := pool.takeNameMutex()
if err != nil {
return nil, err
}
defer func() {
windows.ReleaseMutex(mutex)
windows.CloseHandle(mutex)
}()
// Create a list of network devices. // Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil { if err != nil {
return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err) return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err)
} }
defer devInfoList.Close() defer devInfo.Close()
// Windows requires each interface to have a different name. When // Windows requires each interface to have a different name. When
// enforcing this, Windows treats interface names case-insensitive. If an // enforcing this, Windows treats interface names case-insensitive. If an
@@ -102,7 +123,7 @@ func GetInterface(ifname string) (*Wintun, error) {
ifname = strings.ToLower(ifname) ifname = strings.ToLower(ifname)
for index := 0; ; index++ { for index := 0; ; index++ {
deviceData, err := devInfoList.EnumDeviceInfo(index) devInfoData, err := devInfo.EnumDeviceInfo(index)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -110,26 +131,37 @@ func GetInterface(ifname string) (*Wintun, error) {
continue continue
} }
wintun, err := makeWintun(devInfoList, deviceData) // Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
if err != nil {
continue
}
if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
continue
}
wintun, err := makeWintun(devInfo, devInfoData, pool)
if err != nil { if err != nil {
continue continue
} }
// TODO: is there a better way than comparing ifnames? // TODO: is there a better way than comparing ifnames?
ifname2, err := wintun.InterfaceName() ifname2, err := wintun.Name()
if err != nil { if err != nil {
continue continue
} }
ifname2 = strings.ToLower(ifname2)
ifname3 := removeNumberedSuffix(ifname2)
if ifname == strings.ToLower(ifname2) { if ifname == ifname2 || ifname == ifname3 {
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
if err != nil { if err != nil {
return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err) return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
} }
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
for index := 0; ; index++ { for index := 0; ; index++ {
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, index) driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -138,12 +170,20 @@ func GetInterface(ifname string) (*Wintun, error) {
} }
// Get driver info details. // Get driver info details.
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData) driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil { if err != nil {
continue continue
} }
if driverDetailData.IsCompatible(hardwareID) { if driverDetailData.IsCompatible(hardwareID) {
isMember, err := pool.isMember(devInfo, devInfoData)
if err != nil {
return nil, err
}
if !isMember {
return nil, windows.ERROR_ALREADY_EXISTS
}
return wintun, nil return wintun, nil
} }
} }
@@ -156,24 +196,31 @@ func GetInterface(ifname string) (*Wintun, error) {
return nil, windows.ERROR_OBJECT_NOT_FOUND return nil, windows.ERROR_OBJECT_NOT_FOUND
} }
// CreateInterface creates a Wintun interface. description is a string that // CreateInterface creates a Wintun interface. ifname is the requested name of
// supplies the text description of the device. The description is optional // the interface, while requestedGUID is the GUID of the created network
// and can be "". requestedGUID is the GUID of the created network interface, // interface, which then influences NLA generation deterministically. If it is
// which then influences NLA generation deterministically. If it is set to nil, // set to nil, the GUID is chosen by the system at random, and hence a new NLA
// the GUID is chosen by the system at random, and hence a new NLA entry is // entry is created for each new interface. It is called "requested" GUID
// created for each new interface. It is called "requested" GUID because the // because the API it uses is completely undocumented, and so there could be minor
// API it uses is completely undocumented, and so there could be minor
// interesting complications with its usage. This function returns the network // interesting complications with its usage. This function returns the network
// interface ID and a flag if reboot is required. // interface ID and a flag if reboot is required.
// func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wintun *Interface, rebootRequired bool, err error) {
func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *Wintun, rebootRequired bool, err error) { mutex, err := pool.takeNameMutex()
if err != nil {
return
}
defer func() {
windows.ReleaseMutex(mutex)
windows.CloseHandle(mutex)
}()
// Create an empty device info set for network adapter device class. // Create an empty device info set for network adapter device class.
devInfoList, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "") devInfo, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err) err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err)
return return
} }
defer devInfoList.Close() defer devInfo.Close()
// Get the device class name from GUID. // Get the device class name from GUID.
className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "") className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "")
@@ -183,43 +230,44 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
} }
// Create a new device info element and add it to the device info set. // Create a new device info element and add it to the device info set.
deviceData, err := devInfoList.CreateDeviceInfo(className, &deviceClassNetGUID, description, 0, setupapi.DICD_GENERATE_ID) deviceTypeName := pool.deviceTypeName()
devInfoData, err := devInfo.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID)
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err) err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err)
return return
} }
err = setQuietInstall(devInfoList, deviceData) err = setQuietInstall(devInfo, devInfoData)
if err != nil { if err != nil {
err = fmt.Errorf("Setting quiet installation failed: %v", err) err = fmt.Errorf("Setting quiet installation failed: %v", err)
return return
} }
// Set a device information element as the selected member of a device information set. // Set a device information element as the selected member of a device information set.
err = devInfoList.SetSelectedDevice(deviceData) err = devInfo.SetSelectedDevice(devInfoData)
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err) err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err)
return return
} }
// Set Plug&Play device hardware ID property. // Set Plug&Play device hardware ID property.
err = devInfoList.SetDeviceRegistryPropertyString(deviceData, setupapi.SPDRP_HARDWAREID, hardwareID) err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_HARDWAREID, hardwareID)
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err) err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err)
return return
} }
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err) err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
return return
} }
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
driverDate := windows.Filetime{} driverDate := windows.Filetime{}
driverVersion := uint64(0) driverVersion := uint64(0)
for index := 0; ; index++ { // TODO: This loop takes ~600ms for index := 0; ; index++ { // TODO: This loop takes ~600ms
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, index) driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -229,13 +277,13 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
// Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match. // Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match.
if driverData.IsNewer(driverDate, driverVersion) { if driverData.IsNewer(driverDate, driverVersion) {
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData) driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil { if err != nil {
continue continue
} }
if driverDetailData.IsCompatible(hardwareID) { if driverDetailData.IsCompatible(hardwareID) {
err := devInfoList.SetSelectedDriver(deviceData, driverData) err := devInfo.SetSelectedDriver(devInfoData, driverData)
if err != nil { if err != nil {
continue continue
} }
@@ -251,42 +299,6 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
return return
} }
// Call appropriate class installer.
err = devInfoList.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, deviceData)
if err != nil {
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err)
return
}
// Register device co-installers if any. (Ignore errors)
devInfoList.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, deviceData)
var key registry.Key
const pollTimeout = time.Millisecond * 50
for i := 0; i < int(waitForRegistryTimeout/pollTimeout); i++ {
if i != 0 {
time.Sleep(pollTimeout)
}
key, err = devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY)
if err == nil {
break
}
}
if err != nil {
err = fmt.Errorf("SetupDiOpenDevRegKey failed: %v", err)
return
}
defer key.Close()
if requestedGUID != nil {
err = key.SetStringValue("NetSetupAnticipatedInstanceId", requestedGUID.String())
if err != nil {
err = fmt.Errorf("SetStringValue(NetSetupAnticipatedInstanceId) failed: %v", err)
return
}
}
// Install interfaces if any. (Ignore errors)
devInfoList.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, deviceData)
defer func() { defer func() {
if err != nil { if err != nil {
// The interface failed to install, or the interface ID was unobtainable. Clean-up. // The interface failed to install, or the interface ID was unobtainable. Clean-up.
@@ -296,10 +308,10 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
} }
// Set class installer parameters for DIF_REMOVE. // Set class installer parameters for DIF_REMOVE.
if devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil { if devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil {
// Call appropriate class installer. // Call appropriate class installer.
if devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) == nil { if devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData) == nil {
rebootRequired = rebootRequired || checkReboot(devInfoList, deviceData) rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
} }
} }
@@ -307,64 +319,85 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
} }
}() }()
// Call appropriate class installer.
err = devInfo.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, devInfoData)
if err != nil {
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err)
return
}
// Register device co-installers if any. (Ignore errors)
devInfo.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, devInfoData)
var netDevRegKey registry.Key
const pollTimeout = time.Millisecond * 50
for i := 0; i < int(waitForRegistryTimeout/pollTimeout); i++ {
if i != 0 {
time.Sleep(pollTimeout)
}
netDevRegKey, err = devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY)
if err == nil {
break
}
}
if err != nil {
err = fmt.Errorf("SetupDiOpenDevRegKey failed: %v", err)
return
}
defer netDevRegKey.Close()
if requestedGUID != nil {
err = netDevRegKey.SetStringValue("NetSetupAnticipatedInstanceId", requestedGUID.String())
if err != nil {
err = fmt.Errorf("SetStringValue(NetSetupAnticipatedInstanceId) failed: %v", err)
return
}
}
// Install interfaces if any. (Ignore errors)
devInfo.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, devInfoData)
// Install the device. // Install the device.
err = devInfoList.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, deviceData) err = devInfo.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, devInfoData)
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err) err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err)
return return
} }
rebootRequired = checkReboot(devInfoList, deviceData) rebootRequired = checkReboot(devInfo, devInfoData)
err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_DEVICEDESC, deviceTypeName)
if err != nil {
err = fmt.Errorf("SetDeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
return
}
// DIF_INSTALLDEVICE returns almost immediately, while the device installation // DIF_INSTALLDEVICE returns almost immediately, while the device installation
// continues in the background. It might take a while, before all registry // continues in the background. It might take a while, before all registry
// keys and values are populated. // keys and values are populated.
_, err = registryEx.GetStringValueWait(key, "NetCfgInstanceId", waitForRegistryTimeout) _, err = registryEx.GetStringValueWait(netDevRegKey, "NetCfgInstanceId", waitForRegistryTimeout)
if err != nil { if err != nil {
err = fmt.Errorf("GetStringValueWait(NetCfgInstanceId) failed: %v", err) err = fmt.Errorf("GetStringValueWait(NetCfgInstanceId) failed: %v", err)
return return
} }
_, err = registryEx.GetIntegerValueWait(key, "NetLuidIndex", waitForRegistryTimeout) _, err = registryEx.GetIntegerValueWait(netDevRegKey, "NetLuidIndex", waitForRegistryTimeout)
if err != nil { if err != nil {
err = fmt.Errorf("GetIntegerValueWait(NetLuidIndex) failed: %v", err) err = fmt.Errorf("GetIntegerValueWait(NetLuidIndex) failed: %v", err)
return return
} }
_, err = registryEx.GetIntegerValueWait(key, "*IfType", waitForRegistryTimeout) _, err = registryEx.GetIntegerValueWait(netDevRegKey, "*IfType", waitForRegistryTimeout)
if err != nil { if err != nil {
err = fmt.Errorf("GetIntegerValueWait(*IfType) failed: %v", err) err = fmt.Errorf("GetIntegerValueWait(*IfType) failed: %v", err)
return return
} }
// Get network interface. // Get network interface.
wintun, err = makeWintun(devInfoList, deviceData) wintun, err = makeWintun(devInfo, devInfoData, pool)
if err != nil { if err != nil {
err = fmt.Errorf("makeWintun failed: %v", err) err = fmt.Errorf("makeWintun failed: %v", err)
return return
} }
// Wait for network registry key to emerge and populate.
key, err = registryEx.OpenKeyWait(
registry.LOCAL_MACHINE,
wintun.netRegKeyName(),
registry.QUERY_VALUE|registry.NOTIFY,
waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("makeWintun failed: %v", err)
return
}
defer key.Close()
_, err = registryEx.GetStringValueWait(key, "Name", waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("GetStringValueWait(Name) failed: %v", err)
return
}
_, err = registryEx.GetStringValueWait(key, "PnPInstanceId", waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("GetStringValueWait(PnPInstanceId) failed: %v", err)
return
}
// Wait for TCP/IP adapter registry key to emerge and populate. // Wait for TCP/IP adapter registry key to emerge and populate.
key, err = registryEx.OpenKeyWait( tcpipAdapterRegKey, err := registryEx.OpenKeyWait(
registry.LOCAL_MACHINE, registry.LOCAL_MACHINE,
wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE|registry.NOTIFY, wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE|registry.NOTIFY,
waitForRegistryTimeout) waitForRegistryTimeout)
@@ -372,8 +405,8 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", wintun.tcpipAdapterRegKeyName(), err) err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", wintun.tcpipAdapterRegKeyName(), err)
return return
} }
defer key.Close() defer tcpipAdapterRegKey.Close()
_, err = registryEx.GetStringValueWait(key, "IpConfig", waitForRegistryTimeout) _, err = registryEx.GetStringValueWait(tcpipAdapterRegKey, "IpConfig", waitForRegistryTimeout)
if err != nil { if err != nil {
err = fmt.Errorf("GetStringValueWait(IpConfig) failed: %v", err) err = fmt.Errorf("GetStringValueWait(IpConfig) failed: %v", err)
return return
@@ -386,28 +419,23 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
} }
// Wait for TCP/IP interface registry key to emerge. // Wait for TCP/IP interface registry key to emerge.
key, err = registryEx.OpenKeyWait( tcpipInterfaceRegKey, err := registryEx.OpenKeyWait(
registry.LOCAL_MACHINE, registry.LOCAL_MACHINE,
tcpipInterfaceRegKeyName, registry.QUERY_VALUE, tcpipInterfaceRegKeyName, registry.QUERY_VALUE|registry.SET_VALUE,
waitForRegistryTimeout) waitForRegistryTimeout)
if err != nil { if err != nil {
err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", tcpipInterfaceRegKeyName, err) err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", tcpipInterfaceRegKeyName, err)
return return
} }
key.Close() defer tcpipInterfaceRegKey.Close()
//
// All the registry keys and values we're relying on are present now.
//
// Disable dead gateway detection on our interface. // Disable dead gateway detection on our interface.
key, err = registry.OpenKey(registry.LOCAL_MACHINE, tcpipInterfaceRegKeyName, registry.SET_VALUE) tcpipInterfaceRegKey.SetDWordValue("EnableDeadGWDetect", 0)
err = wintun.SetName(ifname)
if err != nil { if err != nil {
err = fmt.Errorf("Error opening interface-specific TCP/IP network registry key: %v", err) err = fmt.Errorf("Unable to set name of Wintun interface: %v", err)
return return
} }
key.SetDWordValue("EnableDeadGWDetect", 0)
key.Close()
return return
} }
@@ -415,15 +443,15 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
// DeleteInterface deletes a Wintun interface. This function succeeds // DeleteInterface deletes a Wintun interface. This function succeeds
// if the interface was not found. It returns a bool indicating whether // if the interface was not found. It returns a bool indicating whether
// a reboot is required. // a reboot is required.
func (wintun *Wintun) DeleteInterface() (rebootRequired bool, err error) { func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) {
devInfoList, deviceData, err := wintun.deviceData() devInfo, devInfoData, err := wintun.devInfoData()
if err == windows.ERROR_OBJECT_NOT_FOUND { if err == windows.ERROR_OBJECT_NOT_FOUND {
return false, nil return false, nil
} }
if err != nil { if err != nil {
return false, err return false, err
} }
defer devInfoList.Close() defer devInfo.Close()
// Remove the device. // Remove the device.
removeDeviceParams := setupapi.RemoveDeviceParams{ removeDeviceParams := setupapi.RemoveDeviceParams{
@@ -432,32 +460,42 @@ func (wintun *Wintun) DeleteInterface() (rebootRequired bool, err error) {
} }
// Set class installer parameters for DIF_REMOVE. // Set class installer parameters for DIF_REMOVE.
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
if err != nil { if err != nil {
return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err) return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err)
} }
// Call appropriate class installer. // Call appropriate class installer.
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
if err != nil { if err != nil {
return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err) return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err)
} }
return checkReboot(devInfoList, deviceData), nil return checkReboot(devInfo, devInfoData), nil
} }
// DeleteAllInterfaces deletes all Wintun interfaces, and returns which // DeleteMatchingInterfaces deletes all Wintun interfaces, which match
// ones it deleted, whether a reboot is required after, and which errors // given criteria, and returns which ones it deleted, whether a reboot
// occurred during the process. // is required after, and which errors occurred during the process.
func DeleteAllInterfaces() (deviceInstancesDeleted []uint32, rebootRequired bool, errors []error) { func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool) (deviceInstancesDeleted []uint32, rebootRequired bool, errors []error) {
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") mutex, err := pool.takeNameMutex()
if err != nil {
errors = append(errors, err)
return
}
defer func() {
windows.ReleaseMutex(mutex)
windows.CloseHandle(mutex)
}()
devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil { if err != nil {
return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())} return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())}
} }
defer devInfoList.Close() defer devInfo.Close()
for i := 0; ; i++ { for i := 0; ; i++ {
deviceData, err := devInfoList.EnumDeviceInfo(i) devInfoData, err := devInfo.EnumDeviceInfo(i)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -465,22 +503,31 @@ func DeleteAllInterfaces() (deviceInstancesDeleted []uint32, rebootRequired bool
continue continue
} }
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) // Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
if err != nil { if err != nil {
continue continue
} }
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
continue
}
err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
if err != nil {
continue
}
defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
isWintun := false isWintun := false
for j := 0; ; j++ { for j := 0; ; j++ {
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, j) driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, j)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
} }
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData) driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil { if err != nil {
continue continue
} }
@@ -493,36 +540,71 @@ func DeleteAllInterfaces() (deviceInstancesDeleted []uint32, rebootRequired bool
continue continue
} }
err = setQuietInstall(devInfoList, deviceData) isMember, err := pool.isMember(devInfo, devInfoData)
if err != nil {
errors = append(errors, err)
continue
}
if !isMember {
continue
}
wintun, err := makeWintun(devInfo, devInfoData, pool)
if err != nil {
errors = append(errors, fmt.Errorf("Unable to make Wintun interface object: %v", err))
continue
}
if !matches(wintun) {
continue
}
err = setQuietInstall(devInfo, devInfoData)
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
continue continue
} }
inst := deviceData.DevInst inst := devInfoData.DevInst
removeDeviceParams := setupapi.RemoveDeviceParams{ removeDeviceParams := setupapi.RemoveDeviceParams{
ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE), ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
Scope: setupapi.DI_REMOVEDEVICE_GLOBAL, Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
} }
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
continue continue
} }
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
continue continue
} }
rebootRequired = rebootRequired || checkReboot(devInfoList, deviceData) rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
deviceInstancesDeleted = append(deviceInstancesDeleted, inst) deviceInstancesDeleted = append(deviceInstancesDeleted, inst)
} }
return return
} }
// isMember checks if SPDRP_DEVICEDESC or SPDRP_FRIENDLYNAME match device type name.
func (pool Pool) isMember(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) (bool, error) {
deviceDescVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_DEVICEDESC)
if err != nil {
return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
}
deviceDesc, _ := deviceDescVal.(string)
friendlyNameVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_FRIENDLYNAME)
if err != nil {
return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err)
}
friendlyName, _ := friendlyNameVal.(string)
deviceTypeName := pool.deviceTypeName()
return friendlyName == deviceTypeName || deviceDesc == deviceTypeName ||
removeNumberedSuffix(friendlyName) == deviceTypeName || removeNumberedSuffix(deviceDesc) == deviceTypeName, nil
}
// checkReboot checks device install parameters if a system reboot is required. // checkReboot checks device install parameters if a system reboot is required.
func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) bool { func checkReboot(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) bool {
devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData) devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
if err != nil { if err != nil {
return false return false
} }
@@ -531,57 +613,117 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
} }
// setQuietInstall sets device install parameters for a quiet installation // setQuietInstall sets device install parameters for a quiet installation
func setQuietInstall(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) error { func setQuietInstall(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) error {
devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData) devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
if err != nil { if err != nil {
return err return err
} }
devInstallParams.Flags |= setupapi.DI_QUIETINSTALL devInstallParams.Flags |= setupapi.DI_QUIETINSTALL
return deviceInfoSet.SetDeviceInstallParams(deviceInfoData, devInstallParams) return devInfo.SetDeviceInstallParams(devInfoData, devInstallParams)
} }
// InterfaceName returns the name of the Wintun interface. // deviceTypeName returns pool-specific device type name.
func (wintun *Wintun) InterfaceName() (string, error) { func (pool Pool) deviceTypeName() string {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.netRegKeyName(), registry.QUERY_VALUE) return fmt.Sprintf("%s Tunnel", pool)
if err != nil { }
return "", fmt.Errorf("Network-specific registry key open failed: %v", err)
// Name returns the name of the Wintun interface.
func (wintun *Interface) Name() (string, error) {
return nci.ConnectionName(&wintun.cfgInstanceID)
}
// SetName sets name of the Wintun interface.
func (wintun *Interface) SetName(ifname string) error {
const maxSuffix = 1000
availableIfname := ifname
for i := 0; ; i++ {
err := nci.SetConnectionName(&wintun.cfgInstanceID, availableIfname)
if err == windows.ERROR_DUP_NAME {
duplicateGuid, err2 := iphlpapi.InterfaceGUIDFromAlias(availableIfname)
if err2 == nil {
for j := 0; j < maxSuffix; j++ {
proposal := fmt.Sprintf("%s %d", ifname, j+1)
if proposal == availableIfname {
continue
}
err2 = nci.SetConnectionName(duplicateGuid, proposal)
if err2 == windows.ERROR_DUP_NAME {
continue
}
if err2 == nil {
err = nci.SetConnectionName(&wintun.cfgInstanceID, availableIfname)
if err == nil {
break
}
}
break
}
}
}
if err == nil {
break
}
if i > maxSuffix || err != windows.ERROR_DUP_NAME {
return fmt.Errorf("NciSetConnectionName failed: %v", err)
}
availableIfname = fmt.Sprintf("%s %d", ifname, i+1)
} }
defer key.Close()
// Get the interface name. // TODO: This should use NetSetup2 so that it doesn't get unset.
return registryEx.GetStringValue(key, "Name") deviceRegKey, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.deviceRegKeyName(), registry.SET_VALUE)
}
// SetInterfaceName sets name of the Wintun interface.
func (wintun *Wintun) SetInterfaceName(ifname string) error {
// We have to tell the various runtime COM services about the new name too. We ignore the
// error because netshell isn't available on servercore.
// TODO: netsh.exe falls back to NciSetConnection in this case. If somebody complains, maybe
// we should do the same.
netshell.HrRenameConnection(&wintun.cfgInstanceID, windows.StringToUTF16Ptr(ifname))
// Set the interface name. The above line should have done this too, but in case it failed, we force it.
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.netRegKeyName(), registry.SET_VALUE)
if err != nil { if err != nil {
return fmt.Errorf("Network-specific registry key open failed: %v", err) return fmt.Errorf("Device-level registry key open failed: %v", err)
} }
defer key.Close() defer deviceRegKey.Close()
return key.SetStringValue("Name", ifname) err = deviceRegKey.SetStringValue("FriendlyName", wintun.pool.deviceTypeName())
} if err != nil {
return fmt.Errorf("SetStringValue(FriendlyName) failed: %v", err)
// netRegKeyName returns the interface-specific network registry key name. }
func (wintun *Wintun) netRegKeyName() string { return nil
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Control\\Network\\%v\\%v\\Connection", deviceClassNetGUID, wintun.cfgInstanceID)
} }
// tcpipAdapterRegKeyName returns the adapter-specific TCP/IP network registry key name. // tcpipAdapterRegKeyName returns the adapter-specific TCP/IP network registry key name.
func (wintun *Wintun) tcpipAdapterRegKeyName() string { func (wintun *Interface) tcpipAdapterRegKeyName() string {
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", wintun.cfgInstanceID) return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", wintun.cfgInstanceID)
} }
// deviceRegKeyName returns the device-level registry key name.
func (wintun *Interface) deviceRegKeyName() string {
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Enum\\%v", wintun.devInstanceID)
}
// Version returns the version of the Wintun driver and NDIS system currently loaded.
func (wintun *Interface) Version() (driverVersion string, ndisVersion string, err error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, "SYSTEM\\CurrentControlSet\\Services\\Wintun", registry.QUERY_VALUE)
if err != nil {
return
}
defer key.Close()
driverMajor, _, err := key.GetIntegerValue("DriverMajorVersion")
if err != nil {
return
}
driverMinor, _, err := key.GetIntegerValue("DriverMinorVersion")
if err != nil {
return
}
ndisMajor, _, err := key.GetIntegerValue("NdisMajorVersion")
if err != nil {
return
}
ndisMinor, _, err := key.GetIntegerValue("NdisMinorVersion")
if err != nil {
return
}
driverVersion = fmt.Sprintf("%d.%d", driverMajor, driverMinor)
ndisVersion = fmt.Sprintf("%d.%d", ndisMajor, ndisMinor)
return
}
// tcpipInterfaceRegKeyName returns the interface-specific TCP/IP network registry key name. // tcpipInterfaceRegKeyName returns the interface-specific TCP/IP network registry key name.
func (wintun *Wintun) tcpipInterfaceRegKeyName() (path string, err error) { func (wintun *Interface) tcpipInterfaceRegKeyName() (path string, err error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE) key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE)
if err != nil { if err != nil {
return "", fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %v", err) return "", fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %v", err)
@@ -597,18 +739,18 @@ func (wintun *Wintun) tcpipInterfaceRegKeyName() (path string, err error) {
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil
} }
// deviceData returns TUN device info list handle and interface device info // devInfoData returns TUN device info list handle and interface device info
// data. The device info list handle must be closed after use. In case the // data. The device info list handle must be closed after use. In case the
// device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned. // device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned.
func (wintun *Wintun) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData, error) { func (wintun *Interface) devInfoData() (setupapi.DevInfo, *setupapi.DevInfoData, error) {
// Create a list of network devices. // Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error()) return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())
} }
for index := 0; ; index++ { for index := 0; ; index++ {
deviceData, err := devInfoList.EnumDeviceInfo(index) devInfoData, err := devInfo.EnumDeviceInfo(index)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -618,44 +760,44 @@ func (wintun *Wintun) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData, err
// Get interface ID. // Get interface ID.
// TODO: Store some ID in the Wintun object such that this call isn't required. // TODO: Store some ID in the Wintun object such that this call isn't required.
wintun2, err := makeWintun(devInfoList, deviceData) wintun2, err := makeWintun(devInfo, devInfoData, wintun.pool)
if err != nil { if err != nil {
continue continue
} }
if wintun.cfgInstanceID == wintun2.cfgInstanceID { if wintun.cfgInstanceID == wintun2.cfgInstanceID {
err = setQuietInstall(devInfoList, deviceData) err = setQuietInstall(devInfo, devInfoData)
if err != nil { if err != nil {
devInfoList.Close() devInfo.Close()
return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err) return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err)
} }
return devInfoList, deviceData, nil return devInfo, devInfoData, nil
} }
} }
devInfoList.Close() devInfo.Close()
return 0, nil, windows.ERROR_OBJECT_NOT_FOUND return 0, nil, windows.ERROR_OBJECT_NOT_FOUND
} }
// AdapterHandle returns a handle to the adapter device object. // handle returns a handle to the interface device object.
func (wintun *Wintun) AdapterHandle() (windows.Handle, error) { func (wintun *Interface) handle() (windows.Handle, error) {
interfaces, err := setupapi.CM_Get_Device_Interface_List(wintun.devInstanceID, &deviceInterfaceNetGUID, setupapi.CM_GET_DEVICE_INTERFACE_LIST_PRESENT) interfaces, err := setupapi.CM_Get_Device_Interface_List(wintun.devInstanceID, &deviceInterfaceNetGUID, setupapi.CM_GET_DEVICE_INTERFACE_LIST_PRESENT)
if err != nil { if err != nil {
return windows.InvalidHandle, err return windows.InvalidHandle, fmt.Errorf("Error listing NDIS interfaces: %v", err)
} }
handle, err := windows.CreateFile(windows.StringToUTF16Ptr(interfaces[0]), windows.GENERIC_READ|windows.GENERIC_WRITE, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, 0, 0) handle, err := windows.CreateFile(windows.StringToUTF16Ptr(interfaces[0]), windows.GENERIC_READ|windows.GENERIC_WRITE, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, 0, 0)
if err != nil { if err != nil {
return windows.InvalidHandle, fmt.Errorf("Open NDIS device failed: %v", err) return windows.InvalidHandle, fmt.Errorf("Error opening NDIS device: %v", err)
} }
return handle, nil return handle, nil
} }
// GUID returns the GUID of the interface. // GUID returns the GUID of the interface.
func (wintun *Wintun) GUID() windows.GUID { func (wintun *Interface) GUID() windows.GUID {
return wintun.cfgInstanceID return wintun.cfgInstanceID
} }
// LUID returns the LUID of the interface. // LUID returns the LUID of the interface.
func (wintun *Wintun) LUID() uint64 { func (wintun *Interface) LUID() uint64 {
return ((uint64(wintun.luidIndex) & ((1 << 24) - 1)) << 24) | ((uint64(wintun.ifType) & ((1 << 16) - 1)) << 48) return ((uint64(wintun.luidIndex) & ((1 << 24) - 1)) << 24) | ((uint64(wintun.ifType) & ((1 << 16) - 1)) << 48)
} }