64 Commits

Author SHA1 Message Date
Jason A. Donenfeld
ae88e2a2cd version: bump snapshot 2020-03-20 12:00:53 -06:00
Jason A. Donenfeld
4739708ca4 noise: unify zero checking of ecdh 2020-03-17 23:07:14 -06:00
Tobias Klauser
b33219c2cf global: use RTMGRP_* consts from x/sys/unix
Update the golang.org/x/sys/unix dependency and use the newly introduced
RTMGRP_* consts instead of using the corresponding RTNLGRP_* const to
create a mask.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
2020-03-17 23:07:11 -06:00
Jason A. Donenfeld
9cbcff10dd send: account for zero mtu
Don't divide by zero.
2020-02-14 18:53:55 +01:00
Jason A. Donenfeld
6ed56ff2df device: fix private key removal logic 2020-02-04 22:02:53 +01:00
Jason A. Donenfeld
cb4bb63030 uapi: allow unsetting device private key with /dev/null 2020-02-04 22:02:53 +01:00
Jason A. Donenfeld
05b03c6750 version: bump snapshot 2020-01-21 16:27:19 +01:00
Jason A. Donenfeld
caebdfe9d0 tun: darwin: ignore ENOMEM errors
Coauthored-by: Andrej Mihajlov <and@mullvad.net>
2020-01-15 13:39:37 -05:00
Jason A. Donenfeld
4fa2ea6a2d tun: windows: serialize write calls 2020-01-07 11:40:45 -05:00
Jason A. Donenfeld
89dd065e53 README: update repo urls 2019-12-30 11:53:39 +01:00
Jason A. Donenfeld
ddfad453cf device: SendmsgN mutates the input sockaddr
So we take a new granular lock to prevent concurrent writes from
racing.

WARNING: DATA RACE
Write at 0x00c0011f2740 by goroutine 27:
  golang.org/x/sys/unix.(*SockaddrInet4).sockaddr()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:384
+0x114
  golang.org/x/sys/unix.SendmsgN()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:1304
+0x288
  golang.zx2c4.com/wireguard/device.send4()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:485
+0x11f
  golang.zx2c4.com/wireguard/device.(*nativeBind).Send()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:268
+0x1d6
  golang.zx2c4.com/wireguard/device.(*Peer).SendBuffer()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/peer.go:151
+0x285
  golang.zx2c4.com/wireguard/device.(*Peer).SendHandshakeInitiation()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:163
+0x692
  golang.zx2c4.com/wireguard/device.(*Device).RoutineReadFromTUN()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:318
+0x4b8

Previous write at 0x00c0011f2740 by goroutine 386:
  golang.org/x/sys/unix.(*SockaddrInet4).sockaddr()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:384
+0x114
  golang.org/x/sys/unix.SendmsgN()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:1304
+0x288
  golang.zx2c4.com/wireguard/device.send4()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:485
+0x11f
  golang.zx2c4.com/wireguard/device.(*nativeBind).Send()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:268
+0x1d6
  golang.zx2c4.com/wireguard/device.(*Peer).SendBuffer()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/peer.go:151
+0x285
  golang.zx2c4.com/wireguard/device.(*Peer).SendHandshakeInitiation()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:163
+0x692
  golang.zx2c4.com/wireguard/device.expiredRetransmitHandshake()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/timers.go:110
+0x40c
  golang.zx2c4.com/wireguard/device.(*Peer).NewTimer.func1()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/timers.go:42
+0xd8

Goroutine 27 (running) created at:
  golang.zx2c4.com/wireguard/device.NewDevice()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/device.go:322
+0x5e8
  main.main()
      /go/src/x/main.go:102 +0x58e

Goroutine 386 (finished) created at:
  time.goFunc()
      /usr/local/go/src/time/sleep.go:168 +0x51

Reported-by: Ben Burkert <ben@benburkert.com>
2019-11-28 11:11:13 +01:00
Jason A. Donenfeld
2b242f9393 wintun: manage ring memory manually
It's large and Go's garbage collector doesn't deal with it especially
well.
2019-11-22 13:13:55 +01:00
Jason A. Donenfeld
4cdf805b29 constants: recalculate rekey max based on a one minute flood
Discussed-with: Mathias Hall-Andersen <mathias@hall-andersen.dk>
2019-10-30 14:29:32 +01:00
Jonathan Tooker
f7d0edd2ec global: fix a few typos courtesy of codespell
Signed-off-by: Jonathan Tooker <jonathan.tooker@netprotect.com>
2019-10-22 11:51:25 +02:00
Jason A. Donenfeld
ffffbbcc8a device: allow blackholing sockets 2019-10-21 13:29:57 +02:00
Jason A. Donenfeld
47b02c618b device: remove dead error reporting code 2019-10-21 11:46:54 +02:00
Jason A. Donenfeld
fd23c66fcd namespaceapi: remove tasteless comment 2019-10-21 09:02:29 +02:00
Jason A. Donenfeld
ae492d1b35 device: recheck counters while holding write lock 2019-10-17 15:43:06 +02:00
Jason A. Donenfeld
95fbfccf60 wintun: normalize variable names for their types 2019-10-17 15:30:56 +02:00
Avery Pennarun
c85e4a410f wintun: quickly ignore non-Wintun devices
Some devices take ~2 seconds to enumerate on Windows if we try to get
their instance name.  The hardware id property, on the other hand,
is available right away.

Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
[zx2c4: inlined this to where it makes sense, reused setupapi const]
2019-10-17 15:19:20 +02:00
Avery Pennarun
1b6c8ddbe8 tun: match windows CreateTUN signature to the Linux variant
Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
[zx2c4: fix default value]
2019-10-17 15:19:20 +02:00
Avery Pennarun
0abb6b668c rwcancel: handle EINTR and EAGAIN in unixSelect()
On my Chromebook (Linux 4.19.44 in a VM) and on an AWS EC2
machine, select() was sometimes returning EINTR. This is
harmless and just means you should try again. So let's try
again.

This eliminates a problem where the tunnel fails to come up
correctly and the program needs to be restarted.

Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
2019-10-17 15:19:17 +02:00
David Crawshaw
540d01e54a device: test packets between two fake devices
Signed-off-by: David Crawshaw <crawshaw@tailscale.io>
2019-10-16 11:38:28 +02:00
Jason A. Donenfeld
f2ea85e9f9 version: bump snapshot 2019-10-12 22:34:10 +02:00
Jason A. Donenfeld
222f0f8000 Makefile: remove v prefix 2019-10-08 16:48:18 +02:00
Jason A. Donenfeld
1f146a5e7a wintun: expose version 2019-10-08 09:58:58 +02:00
Jason A. Donenfeld
f2501aa6c8 uapi: allow preventing creation of new peers when updating
This enables race-free updates for wg-dynamic and similar tools.

Suggested-by: Thomas Gschwantner <tharre3@gmail.com>
2019-10-04 11:41:02 +02:00
Jason A. Donenfeld
cb8d01f58a mod: bump versions 2019-10-04 11:41:02 +02:00
Jason A. Donenfeld
01f8ef4e84 winpipe: use x/sys/windows instead of syscall 2019-09-16 23:39:16 -06:00
Jason A. Donenfeld
70f6c42556 wintun: use correct length for security attributes 2019-09-16 19:38:33 -06:00
Jason A. Donenfeld
bb0b2514c0 tun: windows: unify error message format 2019-09-08 13:52:44 -05:00
Jason A. Donenfeld
7c97fdb1e3 version: bump snapshot 2019-09-08 10:56:55 -05:00
Jason A. Donenfeld
84b5a4d83d main: simplify warnings 2019-09-08 10:56:00 -05:00
Jason A. Donenfeld
4cd06c0925 tun: openbsd: check for interface already being up
In some cases, we operate on an already-up interface, or the user brings
up the interface before we start monitoring. For those situations, we
should first check if the interface is already up.

This still technically races between the initial check and the start of
the route loop, but fixing that is a bit ugly and probably not worth it
at the moment.

Reported-by: Theo Buehler <tb@theobuehler.org>
2019-09-07 00:13:23 -05:00
Jason A. Donenfeld
d12eb91f9a namespaceapi: AddSIDToBoundaryDescriptor modifies the handle 2019-09-05 21:48:21 -06:00
Jason A. Donenfeld
73d3bd9cd5 wintun: take mutex first always
This prevents an ABA deadlock with setupapi's internal locks.
2019-09-01 21:32:28 -06:00
Jason A. Donenfeld
f3dba4c194 wintun: consider abandoned mutexes as released 2019-09-01 21:25:47 -06:00
Jason A. Donenfeld
7937840f96 ipc: windows: use protected prefix 2019-08-31 07:48:42 -06:00
Jason A. Donenfeld
e4b957183c winpipe: enforce ownership of client connection 2019-08-30 13:21:47 -06:00
Jason A. Donenfeld
950ca2ba8c wintun: put mutex into private namespace 2019-08-30 11:03:21 -06:00
Jason A. Donenfeld
df2bf34373 namespaceapi: fix mistake 2019-08-30 09:59:36 -06:00
Simon Rozman
a12b765784 namespaceapi: initial version
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-30 15:34:17 +02:00
Jason A. Donenfeld
14df9c3e75 wintun: take mutex so that deletion uses the right name 2019-08-30 15:34:17 +02:00
Jason A. Donenfeld
353f0956bc wintun: move ring constants into module 2019-08-29 13:22:17 -06:00
Jason A. Donenfeld
fa7763c268 wintun: delete all interfaces is not used anymore 2019-08-29 12:22:15 -06:00
Jason A. Donenfeld
d94bae8348 wintun: Wintun->Interface 2019-08-29 12:20:40 -06:00
Jason A. Donenfeld
7689d09336 wintun: keep reference to pool in wintun object 2019-08-29 12:13:16 -06:00
Simon Rozman
69c26dc258 wintun: introduce adapter pools
This makes wintun package reusable for non-WireGuard applications.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-29 18:00:44 +02:00
Jason A. Donenfeld
e862131d3c wintun: simplify rename logic 2019-08-28 19:31:20 -06:00
Jason A. Donenfeld
da28a3e9f3 wintun: give better errors when ndis interface listing fails 2019-08-28 08:39:26 -06:00
Jason A. Donenfeld
3bf3322b2c wintun: also check for numbered suffix and friendly name 2019-08-28 08:08:07 -06:00
Simon Rozman
7305b4ce93 wintun: upgrade deleting all interfaces and make it reusable
DeleteAllInterfaces() didn't check if SPDRP_DEVICEDESC == "WireGuard
Tunnel". It deleted _all_ Wintun adapters, not just WireGuard's.

Furthermore, the DeleteAllInterfaces() was upgraded into a new function
called DeleteMatchingInterfaces() for selectively deletion. This will
be used by WireGuard to clean stale Wintun adapters.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-28 11:39:01 +02:00
Jason A. Donenfeld
26fb615b11 wintun: cleanup earlier 2019-08-27 11:59:15 -06:00
Jason A. Donenfeld
7fbb24afaa wintun: rename duplicate adapters instead of ourselves 2019-08-27 11:59:15 -06:00
Jason A. Donenfeld
d9008ac35c wintun: match suffix numbers 2019-08-26 14:46:43 -06:00
Jason A. Donenfeld
f8198c0428 device: getsockname on linux to determine port
It turns out Go isn't passing the pointer properly so we wound up with a
zero port every time.
2019-08-25 12:45:13 -06:00
Jason A. Donenfeld
0c540ad60e wintun: make description consistent across fields 2019-08-24 12:29:17 +02:00
Jason A. Donenfeld
3cedc22d7b wintun: try multiple names until one isn't a duplicate 2019-08-22 08:52:59 +02:00
Jason A. Donenfeld
68fea631d8 wintun: use nci.dll directly instead of buggy netshell 2019-08-21 09:16:12 +02:00
Jason A. Donenfeld
ef23100a4f wintun: set friendly a bit better
This is still wrong, but NETSETUPPKEY_Driver_FriendlyName seems a bit
tricky to use.
2019-08-20 16:06:55 +02:00
Jason A. Donenfeld
eb786cd7c1 wintun: also set friendly name after setting interface name 2019-08-19 10:12:50 +02:00
Jason A. Donenfeld
333de75370 wintun: defer requires unique variable 2019-08-19 10:12:50 +02:00
Jason A. Donenfeld
d20459dc69 wintun: set adapter description name 2019-08-19 10:12:50 +02:00
Jason A. Donenfeld
01786286c1 tun: windows: don't spin unless we really need it 2019-08-19 10:12:50 +02:00
47 changed files with 1625 additions and 775 deletions

View File

@@ -1,30 +1,16 @@
PREFIX ?= /usr
DESTDIR ?=
BINDIR ?= $(PREFIX)/bin
export GOPATH ?= $(CURDIR)/.gopath
export GO111MODULE := on
all: generate-version-and-build
ifeq ($(shell go env GOOS)|$(wildcard .git),linux|)
$(error Do not build this for Linux. Instead use the Linux kernel module. See wireguard.com/install/ for more info.)
else ifeq ($(shell go env GOOS),linux)
ireallywantobuildon_linux.go:
@printf "WARNING: This software is meant for use on non-Linux\nsystems. For Linux, please use the kernel module\ninstead. See wireguard.com/install/ for more info.\n\n" >&2
@printf 'package main\nconst UseTheKernelModuleInstead = 0xdeadbabe\n' > "$@"
clean-ireallywantobuildon_linux.go:
@rm -f ireallywantobuildon_linux.go
.PHONY: clean-ireallywantobuildon_linux.go
clean: clean-ireallywantobuildon_linux.go
wireguard-go: ireallywantobuildon_linux.go
endif
MAKEFLAGS += --no-print-directory
generate-version-and-build:
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
tag="$$(git describe --dirty 2>/dev/null)" && \
ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$$tag")" && \
ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$${tag#v}")" && \
[ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \
echo "$$ver" > device/version.go && \
git update-index --assume-unchanged device/version.go || true

View File

@@ -18,7 +18,7 @@ To run wireguard-go without forking to the background, pass `-f` or `--foregroun
$ wireguard-go -f wg0
```
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/WireGuard/about/src/tools/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.

View File

@@ -18,7 +18,7 @@ const (
sockoptIPV6_UNICAST_IF = 31
)
func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
bytes := make([]byte, 4)
binary.BigEndian.PutUint32(bytes, interfaceIndex)
@@ -41,10 +41,11 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
if err != nil {
return err
}
device.net.bind.(*nativeBind).blackhole4 = blackhole
return nil
}
func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
if err != nil {
return err
@@ -58,5 +59,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
if err != nil {
return err
}
device.net.bind.(*nativeBind).blackhole6 = blackhole
return nil
}

View File

@@ -23,6 +23,8 @@ import (
type nativeBind struct {
ipv4 *net.UDPConn
ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
}
type NativeEndpoint net.UDPAddr
@@ -159,11 +161,17 @@ func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
if bind.ipv4 == nil {
return syscall.EAFNOSUPPORT
}
if bind.blackhole4 {
return nil
}
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
} else {
if bind.ipv6 == nil {
return syscall.EAFNOSUPPORT
}
if bind.blackhole6 {
return nil
}
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
}
return err

View File

@@ -7,7 +7,7 @@
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code:
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
* https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
@@ -43,6 +43,7 @@ type IPv6Source struct {
}
type NativeEndpoint struct {
sync.Mutex
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
src [unsafe.Sizeof(IPv6Source{})]byte
isV6 bool
@@ -117,7 +118,7 @@ func createNetlinkRouteSocket() (int, error) {
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
Groups: unix.RTMGRP_IPV4_ROUTE,
}
err = unix.Bind(sock, saddr)
if err != nil {
@@ -145,7 +146,7 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
go bind.routineRouteListener(device)
// attempt ipv6 bind, update port if succesful
// attempt ipv6 bind, update port if successful
bind.sock6, newPort, err = create6(port)
if err != nil {
@@ -157,7 +158,7 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
port = newPort
}
// attempt ipv4 bind, update port if succesful
// attempt ipv4 bind, update port if successful
bind.sock4, newPort, err = create4(port)
if err != nil {
@@ -391,6 +392,11 @@ func create4(port uint16) (int, uint16, error) {
return FD_ERR, 0, err
}
sa, err := unix.Getsockname(fd)
if err == nil {
addr.Port = sa.(*unix.SockaddrInet4).Port
}
return fd, uint16(addr.Port), err
}
@@ -450,6 +456,11 @@ func create6(port uint16) (int, uint16, error) {
return FD_ERR, 0, err
}
sa, err := unix.Getsockname(fd)
if err == nil {
addr.Port = sa.(*unix.SockaddrInet6).Port
}
return fd, uint16(addr.Port), err
}
@@ -472,7 +483,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
},
}
end.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.Unlock()
if err == nil {
return nil
@@ -483,7 +496,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{}
end.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.Unlock()
}
return err
@@ -512,7 +527,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
cmsg.pktinfo.Ifindex = 0
}
end.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.Unlock()
if err == nil {
return nil
@@ -523,7 +540,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{}
end.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.Unlock()
}
return err
@@ -531,7 +550,7 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
// construct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
@@ -563,7 +582,7 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
// construct message header
var cmsg struct {
cmsghdr unix.Cmsghdr

View File

@@ -12,7 +12,7 @@ import (
/* Specification constants */
const (
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
RekeyAfterMessages = (1 << 60)
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90

View File

@@ -236,24 +236,12 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// do static-static DH pre-computations
rmKey := device.staticIdentity.privateKey.IsZero()
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for key, peer := range device.peers.keyMap {
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
if rmKey {
handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
}
if isZero(handshake.precomputedStaticStatic[:]) {
unsafeRemovePeer(device, peer, key)
} else {
expiredPeers = append(expiredPeers, peer)
}
}
for _, peer := range lockedPeers {
peer.handshake.mutex.RUnlock()

View File

@@ -5,54 +5,212 @@
package device
/* Create two device instances and simulate full WireGuard interaction
* without network dependencies
*/
import (
"bufio"
"bytes"
"encoding/binary"
"io"
"net"
"os"
"strings"
"testing"
"time"
"golang.zx2c4.com/wireguard/tun"
)
func TestDevice(t *testing.T) {
// prepare tun devices for generating traffic
tun1 := newDummyTUN("tun1")
tun2 := newDummyTUN("tun2")
_ = tun1
_ = tun2
// prepare endpoints
end1, err := CreateDummyEndpoint()
if err != nil {
t.Error("failed to create endpoint:", err.Error())
}
end2, err := CreateDummyEndpoint()
if err != nil {
t.Error("failed to create endpoint:", err.Error())
}
_ = end1
_ = end2
// create binds
}
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
func TestTwoDevicePing(t *testing.T) {
// TODO(crawshaw): pick unused ports on localhost
cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58
listen_port=53511
replace_peers=true
public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725
protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.2/32
endpoint=127.0.0.1:53512`
tun1 := NewChannelTUN()
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
dev1.Up()
defer dev1.Close()
if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil {
t.Fatal(err)
}
tun := newDummyTUN("dummy")
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768
listen_port=53512
replace_peers=true
public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427
protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.1/32
endpoint=127.0.0.1:53511`
tun2 := NewChannelTUN()
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
dev2.Up()
defer dev2.Close()
if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil {
t.Fatal(err)
}
t.Run("ping 1.0.0.1", func(t *testing.T) {
msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
tun2.Outbound <- msg2to1
select {
case msgRecv := <-tun1.Inbound:
if !bytes.Equal(msg2to1, msgRecv) {
t.Error("ping did not transit correctly")
}
case <-time.After(300 * time.Millisecond):
t.Error("ping did not transit")
}
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
tun1.Outbound <- msg1to2
select {
case msgRecv := <-tun2.Inbound:
if !bytes.Equal(msg1to2, msgRecv) {
t.Error("return ping did not transit correctly")
}
case <-time.After(300 * time.Millisecond):
t.Error("return ping did not transit")
}
})
}
func ping(dst, src net.IP) []byte {
localPort := uint16(1337)
seq := uint16(0)
payload := make([]byte, 4)
binary.BigEndian.PutUint16(payload[0:], localPort)
binary.BigEndian.PutUint16(payload[2:], seq)
return genICMPv4(payload, dst, src)
}
// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
func checksum(buf []byte, initial uint16) uint16 {
v := uint32(initial)
for i := 0; i < len(buf)-1; i += 2 {
v += uint32(binary.BigEndian.Uint16(buf[i:]))
}
if len(buf)%2 == 1 {
v += uint32(buf[len(buf)-1]) << 8
}
for v > 0xffff {
v = (v >> 16) + (v & 0xffff)
}
return ^uint16(v)
}
func genICMPv4(payload []byte, dst, src net.IP) []byte {
const (
icmpv4ProtocolNumber = 1
icmpv4Echo = 8
icmpv4ChecksumOffset = 2
icmpv4Size = 8
ipv4Size = 20
ipv4TotalLenOffset = 2
ipv4ChecksumOffset = 10
ttl = 65
)
hdr := make([]byte, ipv4Size+icmpv4Size)
ip := hdr[0:ipv4Size]
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
// https://tools.ietf.org/html/rfc792
icmpv4[0] = icmpv4Echo // type
icmpv4[1] = 0 // code
chksum := ^checksum(icmpv4, checksum(payload, 0))
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
// https://tools.ietf.org/html/rfc760 section 3.1
length := uint16(len(hdr) + len(payload))
ip[0] = (4 << 4) | (ipv4Size / 4)
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl
ip[9] = icmpv4ProtocolNumber
copy(ip[12:], src.To4())
copy(ip[16:], dst.To4())
chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
var v []byte
v = append(v, hdr...)
v = append(v, payload...)
return []byte(v)
}
// TODO(crawshaw): find a reusable home for this. package devicetest?
type ChannelTUN struct {
Inbound chan []byte // incoming packets, closed on TUN close
Outbound chan []byte // outbound packets, blocks forever on TUN close
closed chan struct{}
events chan tun.Event
tun chTun
}
func NewChannelTUN() *ChannelTUN {
c := &ChannelTUN{
Inbound: make(chan []byte),
Outbound: make(chan []byte),
closed: make(chan struct{}),
events: make(chan tun.Event, 1),
}
c.tun.c = c
c.events <- tun.EventUp
return c
}
func (c *ChannelTUN) TUN() tun.Device {
return &c.tun
}
type chTun struct {
c *ChannelTUN
}
func (t *chTun) File() *os.File { return nil }
func (t *chTun) Read(data []byte, offset int) (int, error) {
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case msg := <-t.c.Outbound:
return copy(data[offset:], msg), nil
}
}
// Write is called by the wireguard device to deliver a packet for routing.
func (t *chTun) Write(data []byte, offset int) (int, error) {
if offset == -1 {
close(t.c.closed)
close(t.c.events)
return 0, io.EOF
}
msg := make([]byte, len(data)-offset)
copy(msg, data[offset:])
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case t.c.Inbound <- msg:
return len(data) - offset, nil
}
}
func (t *chTun) Flush() error { return nil }
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
func (t *chTun) Events() chan tun.Event { return t.c.events }
func (t *chTun) Close() error {
t.Write(nil, -1)
return nil
}
func assertNil(t *testing.T, err error) {
@@ -66,3 +224,15 @@ func assertEqual(t *testing.T, a, b []byte) {
t.Fatal(a, "!=", b)
}
}
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
tun := newDummyTUN("dummy")
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
}

View File

@@ -39,13 +39,13 @@ const (
)
const (
MessageInitiationSize = 148 // size of handshake initation message
MessageInitiationSize = 148 // size of handshake initiation message
MessageResponseSize = 92 // size of response message
MessageCookieReplySize = 64 // size of cookie reply message
MessageTransportHeaderSize = 16 // size of data preceeding content in transport message
MessageTransportHeaderSize = 16 // size of data preceding content in transport message
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
MessageKeepaliveSize = MessageTransportSize // size of keepalive
MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message
MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
)
const (
@@ -154,6 +154,7 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
var errZeroECDHResult = errors.New("ECDH returned all zeros")
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -162,12 +163,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errors.New("static shared secret is zero")
}
// create ephemeral key
var err error
handshake.hash = InitialHash
handshake.chainKey = InitialChainKey
@@ -176,31 +172,22 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
return nil, err
}
// assign index
device.indexTable.Delete(handshake.localIndex)
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
handshake.mixHash(handshake.remoteStatic[:])
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
Sender: handshake.localIndex,
}
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
func() {
var key [chacha20poly1305.KeySize]byte
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if isZero(ss[:]) {
return nil, errZeroECDHResult
}
var key [chacha20poly1305.KeySize]byte
KDF2(
&handshake.chainKey,
&key,
@@ -209,23 +196,29 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
)
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
}()
handshake.mixHash(msg.Static[:])
// encrypt timestamp
timestamp := tai64n.Now()
func() {
var key [chacha20poly1305.KeySize]byte
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errZeroECDHResult
}
KDF2(
&handshake.chainKey,
&key,
handshake.chainKey[:],
handshake.precomputedStaticStatic[:],
)
aead, _ := chacha20poly1305.New(key[:])
timestamp := tai64n.Now()
aead, _ = chacha20poly1305.New(key[:])
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
}()
// assign index
device.indexTable.Delete(handshake.localIndex)
msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
handshake.localIndex = msg.Sender
handshake.mixHash(msg.Timestamp[:])
handshake.state = HandshakeInitiationCreated
@@ -250,16 +243,16 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
var err error
var peerPK NoisePublicKey
func() {
var key [chacha20poly1305.KeySize]byte
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
if isZero(ss[:]) {
return nil
}
KDF2(&chainKey, &key, chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
}()
if err != nil {
return nil
}
@@ -273,23 +266,24 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
}
handshake := &peer.handshake
if isZero(handshake.precomputedStaticStatic[:]) {
return nil
}
// verify identity
var timestamp tai64n.Timestamp
var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock()
if isZero(handshake.precomputedStaticStatic[:]) {
handshake.mutex.RUnlock()
return nil
}
KDF2(
&chainKey,
&key,
chainKey[:],
handshake.precomputedStaticStatic[:],
)
aead, _ := chacha20poly1305.New(key[:])
aead, _ = chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
if err != nil {
handshake.mutex.RUnlock()
@@ -315,8 +309,13 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral
if timestamp.After(handshake.lastTimestamp) {
handshake.lastTimestamp = timestamp
handshake.lastInitiationConsumption = time.Now()
}
now := time.Now()
if now.After(handshake.lastInitiationConsumption) {
handshake.lastInitiationConsumption = now
}
handshake.state = HandshakeInitiationConsumed
handshake.mutex.Unlock()

View File

@@ -52,6 +52,15 @@ func (key *NoisePrivateKey) FromHex(src string) (err error) {
return
}
func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
err = loadExactHex(key[:], src)
if key.IsZero() {
return
}
key.clamp()
return
}
func (key NoisePrivateKey) ToHex() string {
return hex.EncodeToString(key[:])
}

View File

@@ -108,7 +108,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
ssIsZero := isZero(handshake.precomputedStaticStatic[:])
handshake.remoteStatic = pk
handshake.mutex.Unlock()
@@ -116,13 +115,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.endpoint = nil
// conditionally add
// add
if !ssIsZero {
device.peers.keyMap[pk] = peer
} else {
return nil, nil
}
// start peer

View File

@@ -220,10 +220,7 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, reply)
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
if err != nil {
device.log.Error.Println("Failed to send cookie reply:", err)
}
return err
return nil
}
func (peer *Peer) keepKeyFreshSending() {
@@ -518,11 +515,19 @@ func (device *Device) RoutineEncryption() {
// pad content to multiple of 16
mtu := int(atomic.LoadInt32(&device.tun.mtu))
lastUnit := len(elem.packet) % mtu
var paddedSize int
if mtu == 0 {
paddedSize = (len(elem.packet) + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
} else {
lastUnit := len(elem.packet)
if lastUnit > mtu {
lastUnit %= mtu
}
paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
if paddedSize > mtu {
paddedSize = mtu
}
}
for i := len(elem.packet); i < paddedSize; i++ {
elem.packet = append(elem.packet, 0)
}

View File

@@ -113,6 +113,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
var peer *Peer
dummy := false
createdNewPeer := false
deviceConfig := true
for scanner.Scan() {
@@ -137,7 +138,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
switch key {
case "private_key":
var sk NoisePrivateKey
err := sk.FromHex(value)
err := sk.FromMaybeZeroHex(value)
if err != nil {
logError.Println("Failed to set private_key:", err)
return &IPCError{ipc.IpcErrorInvalid}
@@ -237,7 +238,8 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
peer = device.LookupPeer(publicKey)
}
if peer == nil {
createdNewPeer = peer == nil
if createdNewPeer {
peer, err = device.NewPeer(publicKey)
if err != nil {
logError.Println("Failed to create new peer:", err)
@@ -251,6 +253,20 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
}
}
case "update_only":
// allow disabling of creation
if value != "true" {
logError.Println("Failed to set update only, invalid value:", value)
return &IPCError{ipc.IpcErrorInvalid}
}
if createdNewPeer && !dummy {
device.RemovePeer(peer.handshake.remoteStatic)
peer = &Peer{}
dummy = true
}
case "remove":
// remove currently selected peer from device

View File

@@ -1,3 +1,3 @@
package device
const WireGuardGoVersion = "0.0.20190805"
const WireGuardGoVersion = "0.0.20200320"

View File

@@ -1,15 +0,0 @@
// +build !android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package main
const DoNotUseThisProgramOnLinux = UseTheKernelModuleInstead
// --------------------------------------------------------
// Do not use this on Linux. Instead use the kernel module.
// See wireguard.com/install for more information.
// --------------------------------------------------------

7
go.mod
View File

@@ -3,7 +3,8 @@ module golang.zx2c4.com/wireguard
go 1.12
require (
golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980
golang.org/x/sys v0.0.0-20190618155005-516e3c20635f
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc
golang.org/x/net v0.0.0-20191003171128-d98b1b443823
golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527
golang.org/x/text v0.3.2
)

15
go.sum
View File

@@ -1,11 +1,14 @@
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56 h1:ZpKuNIejY8P0ExLOVyKhb0WsgG8UdvHXe6TWjY7eL6k=
golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4=
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980 h1:dfGZHvZk057jK2MCeWus/TowKpJ8y4AmooUzdBSR9GU=
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20191003171128-d98b1b443823 h1:Ypyv6BNJh07T1pUSrehkLemqPKXhus2MkfktJ91kRh4=
golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190618155005-516e3c20635f h1:dHNZYIYdq2QuU6w73vZ/DzesPbVlZVYZTtTZmrnsbQ8=
golang.org/x/sys v0.0.0-20190618155005-516e3c20635f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=
golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

View File

@@ -8,6 +8,8 @@ package ipc
import (
"net"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/winpipe"
)
@@ -47,14 +49,22 @@ func (l *UAPIListener) Addr() net.Addr {
return l.listener.Addr()
}
var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR
func init() {
var err error
/* SDDL_DEVOBJ_SYS_ALL from the WDK */
var UAPISecurityDescriptor = "O:SYD:P(A;;GA;;;SY)"
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
if err != nil {
panic(err)
}
}
func UAPIListen(name string) (net.Listener, error) {
config := winpipe.PipeConfig{
SecurityDescriptor: UAPISecurityDescriptor,
}
listener, err := winpipe.ListenPipe("\\\\.\\pipe\\WireGuard\\"+name, &config)
listener, err := winpipe.ListenPipe(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config)
if err != nil {
return nil, err
}

View File

@@ -13,15 +13,16 @@ import (
"runtime"
"sync"
"sync/atomic"
"syscall"
"time"
"golang.org/x/sys/windows"
)
//sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx
//sys createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) = CreateIoCompletionPort
//sys getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
//sys setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
//sys wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx
//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort
//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
type atomicBool int32
@@ -55,7 +56,7 @@ func (e *timeoutError) Temporary() bool { return true }
type timeoutChan chan struct{}
var ioInitOnce sync.Once
var ioCompletionPort syscall.Handle
var ioCompletionPort windows.Handle
// ioResult contains the result of an asynchronous IO operation
type ioResult struct {
@@ -65,12 +66,12 @@ type ioResult struct {
// ioOperation represents an outstanding asynchronous Win32 IO
type ioOperation struct {
o syscall.Overlapped
o windows.Overlapped
ch chan ioResult
}
func initIo() {
h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff)
h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff)
if err != nil {
panic(err)
}
@@ -81,7 +82,7 @@ func initIo() {
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
// It takes ownership of this handle and will close it if it is garbage collected.
type win32File struct {
handle syscall.Handle
handle windows.Handle
wg sync.WaitGroup
wgLock sync.RWMutex
closing atomicBool
@@ -99,7 +100,7 @@ type deadlineHandler struct {
}
// makeWin32File makes a new win32File from an existing file handle
func makeWin32File(h syscall.Handle) (*win32File, error) {
func makeWin32File(h windows.Handle) (*win32File, error) {
f := &win32File{handle: h}
ioInitOnce.Do(initIo)
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
@@ -115,7 +116,7 @@ func makeWin32File(h syscall.Handle) (*win32File, error) {
return f, nil
}
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
func MakeOpenFile(h windows.Handle) (io.ReadWriteCloser, error) {
return makeWin32File(h)
}
@@ -129,7 +130,7 @@ func (f *win32File) closeHandle() {
cancelIoEx(f.handle, nil)
f.wg.Wait()
// at this point, no new IO can start
syscall.Close(f.handle)
windows.Close(f.handle)
f.handle = 0
} else {
f.wgLock.Unlock()
@@ -158,12 +159,12 @@ func (f *win32File) prepareIo() (*ioOperation, error) {
}
// ioCompletionProcessor processes completed async IOs forever
func ioCompletionProcessor(h syscall.Handle) {
func ioCompletionProcessor(h windows.Handle) {
for {
var bytes uint32
var key uintptr
var op *ioOperation
err := getQueuedCompletionStatus(h, &bytes, &key, &op, syscall.INFINITE)
err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE)
if op == nil {
panic(err)
}
@@ -174,7 +175,7 @@ func ioCompletionProcessor(h syscall.Handle) {
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed.
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
if err != syscall.ERROR_IO_PENDING {
if err != windows.ERROR_IO_PENDING {
return int(bytes), err
}
@@ -193,7 +194,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
select {
case r = <-c.ch:
err = r.err
if err == syscall.ERROR_OPERATION_ABORTED {
if err == windows.ERROR_OPERATION_ABORTED {
if f.closing.isSet() {
err = ErrFileClosed
}
@@ -206,7 +207,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
cancelIoEx(f.handle, &c.o)
r = <-c.ch
err = r.err
if err == syscall.ERROR_OPERATION_ABORTED {
if err == windows.ERROR_OPERATION_ABORTED {
err = ErrTimeout
}
}
@@ -231,14 +232,14 @@ func (f *win32File) Read(b []byte) (int, error) {
}
var bytes uint32
err = syscall.ReadFile(f.handle, b, &bytes, &c.o)
err = windows.ReadFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
runtime.KeepAlive(b)
// Handle EOF conditions.
if err == nil && n == 0 && len(b) != 0 {
return 0, io.EOF
} else if err == syscall.ERROR_BROKEN_PIPE {
} else if err == windows.ERROR_BROKEN_PIPE {
return 0, io.EOF
} else {
return n, err
@@ -258,7 +259,7 @@ func (f *win32File) Write(b []byte) (int, error) {
}
var bytes uint32
err = syscall.WriteFile(f.handle, b, &bytes, &c.o)
err = windows.WriteFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
runtime.KeepAlive(b)
return n, err
@@ -273,7 +274,7 @@ func (f *win32File) SetWriteDeadline(deadline time.Time) error {
}
func (f *win32File) Flush() error {
return syscall.FlushFileBuffers(f.handle)
return windows.FlushFileBuffers(f.handle)
}
func (f *win32File) Fd() uintptr {

View File

@@ -6,4 +6,4 @@
package winpipe
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go pipe.go sd.go file.go
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go pipe.go file.go

View File

@@ -16,18 +16,19 @@ import (
"net"
"os"
"runtime"
"syscall"
"time"
"unsafe"
"golang.org/x/sys/windows"
)
//sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateNamedPipeW
//sys createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateFileW
//sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
//sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
//sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW
//sys createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW
//sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
//sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
//sys ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
@@ -41,7 +42,7 @@ type objectAttributes struct {
RootDirectory uintptr
ObjectName *unicodeString
Attributes uintptr
SecurityDescriptor *securityDescriptor
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
SecurityQoS uintptr
}
@@ -51,16 +52,6 @@ type unicodeString struct {
Buffer uintptr
}
type securityDescriptor struct {
Revision byte
Sbz1 byte
Control uint16
Owner uintptr
Group uintptr
Sacl uintptr
Dacl uintptr
}
type ntstatus int32
func (status ntstatus) Err() error {
@@ -71,11 +62,6 @@ func (status ntstatus) Err() error {
}
const (
cERROR_PIPE_BUSY = syscall.Errno(231)
cERROR_NO_DATA = syscall.Errno(232)
cERROR_PIPE_CONNECTED = syscall.Errno(535)
cERROR_SEM_TIMEOUT = syscall.Errno(121)
cSECURITY_SQOS_PRESENT = 0x100000
cSECURITY_ANONYMOUS = 0
@@ -88,8 +74,6 @@ const (
cFILE_PIPE_MESSAGE_TYPE = 1
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
cSE_DACL_PRESENT = 4
)
var (
@@ -170,7 +154,7 @@ func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
// zero-byte message, ensure that all future Read() calls
// also return EOF.
f.readEOF = true
} else if err == syscall.ERROR_MORE_DATA {
} else if err == windows.ERROR_MORE_DATA {
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
// and the message still has more bytes. Treat this as a success, since
// this package presents all named pipes as byte streams.
@@ -188,17 +172,17 @@ func (s pipeAddress) String() string {
}
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
for {
select {
case <-ctx.Done():
return syscall.Handle(0), ctx.Err()
return windows.Handle(0), ctx.Err()
default:
h, err := createFile(*path, syscall.GENERIC_READ|syscall.GENERIC_WRITE, 0, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
h, err := createFile(*path, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
if err == nil {
return h, nil
}
if err != cERROR_PIPE_BUSY {
if err != windows.ERROR_PIPE_BUSY {
return h, &os.PathError{Err: err, Op: "open", Path: *path}
}
// Wait 10 msec and try again. This is a rather simplistic
@@ -211,7 +195,7 @@ func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
// DialPipe connects to a named pipe by path, timing out if the connection
// takes longer than the specified duration. If timeout is nil, then we use
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
func DialPipe(path string, timeout *time.Duration, expectedOwner *windows.SID) (net.Conn, error) {
var absTimeout time.Time
if timeout != nil {
absTimeout = time.Now().Add(*timeout)
@@ -219,7 +203,7 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
absTimeout = time.Now().Add(time.Second * 2)
}
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
conn, err := DialPipeContext(ctx, path)
conn, err := DialPipeContext(ctx, path, expectedOwner)
if err == context.DeadlineExceeded {
return nil, ErrTimeout
}
@@ -228,23 +212,41 @@ func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
// cancellation or timeout.
func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
func DialPipeContext(ctx context.Context, path string, expectedOwner *windows.SID) (net.Conn, error) {
var err error
var h syscall.Handle
var h windows.Handle
h, err = tryDialPipe(ctx, &path)
if err != nil {
return nil, err
}
if expectedOwner != nil {
sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
if err != nil {
windows.Close(h)
return nil, err
}
realOwner, _, err := sd.Owner()
if err != nil {
windows.Close(h)
return nil, err
}
if !realOwner.Equals(expectedOwner) {
windows.Close(h)
return nil, windows.ERROR_ACCESS_DENIED
}
}
var flags uint32
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
if err != nil {
windows.Close(h)
return nil, err
}
f, err := makeWin32File(h)
if err != nil {
syscall.Close(h)
windows.Close(h)
return nil, err
}
@@ -264,7 +266,7 @@ type acceptResponse struct {
}
type win32PipeListener struct {
firstHandle syscall.Handle
firstHandle windows.Handle
path string
config PipeConfig
acceptCh chan (chan acceptResponse)
@@ -272,8 +274,8 @@ type win32PipeListener struct {
doneCh chan int
}
func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (syscall.Handle, error) {
path16, err := syscall.UTF16FromString(path)
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *PipeConfig, first bool) (windows.Handle, error) {
path16, err := windows.UTF16FromString(path)
if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
@@ -285,31 +287,32 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
defer localFree(ntPath.Buffer)
defer windows.LocalFree(windows.Handle(ntPath.Buffer))
oa.ObjectName = &ntPath
// The security descriptor is only needed for the first pipe.
if first {
if sd != nil {
len := uint32(len(sd))
sdb := localAlloc(0, len)
defer localFree(sdb)
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
oa.SecurityDescriptor = sd
} else {
// Construct the default named pipe security descriptor.
var dacl uintptr
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
}
defer localFree(dacl)
sdb := &securityDescriptor{
Revision: 1,
Control: cSE_DACL_PRESENT,
Dacl: dacl,
defer windows.LocalFree(windows.Handle(dacl))
sd, err := windows.NewSecurityDescriptor()
if err != nil {
return 0, fmt.Errorf("creating new security descriptor: %s", err)
}
oa.SecurityDescriptor = sdb
if err = sd.SetDACL((*windows.ACL)(unsafe.Pointer(dacl)), true, false); err != nil {
return 0, fmt.Errorf("assigning dacl: %s", err)
}
sd, err = sd.ToSelfRelative()
if err != nil {
return 0, fmt.Errorf("converting to self-relative: %s", err)
}
oa.SecurityDescriptor = sd
}
}
@@ -319,22 +322,22 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
}
disposition := uint32(cFILE_OPEN)
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE)
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
if first {
disposition = cFILE_CREATE
// By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking
// client connections until the next call with first == false.
access = syscall.SYNCHRONIZE
access = windows.SYNCHRONIZE
}
timeout := int64(-50 * 10000) // 50ms
var (
h syscall.Handle
h windows.Handle
iosb ioStatusBlock
)
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
@@ -350,7 +353,7 @@ func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
}
f, err := makeWin32File(h)
if err != nil {
syscall.Close(h)
windows.Close(h)
return nil, err
}
return f, nil
@@ -401,7 +404,7 @@ func (l *win32PipeListener) listenerRoutine() {
p, err = l.makeConnectedServerPipe()
// If the connection was immediately closed by the client, try
// again.
if err != cERROR_NO_DATA {
if err != windows.ERROR_NO_DATA {
break
}
}
@@ -409,7 +412,7 @@ func (l *win32PipeListener) listenerRoutine() {
closed = err == ErrPipeListenerClosed
}
}
syscall.Close(l.firstHandle)
windows.Close(l.firstHandle)
l.firstHandle = 0
// Notify Close() and Accept() callers that the handle has been closed.
close(l.doneCh)
@@ -417,8 +420,8 @@ func (l *win32PipeListener) listenerRoutine() {
// PipeConfig contain configuration for the pipe listener.
type PipeConfig struct {
// SecurityDescriptor contains a Windows security descriptor in SDDL format.
SecurityDescriptor string
// SecurityDescriptor contains a Windows security descriptor.
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
// MessageMode determines whether the pipe is in byte or message mode. In either
// case the pipe is read in byte mode by default. The only practical difference in
@@ -438,20 +441,10 @@ type PipeConfig struct {
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
// The pipe must not already exist.
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
var (
sd []byte
err error
)
if c == nil {
c = &PipeConfig{}
}
if c.SecurityDescriptor != "" {
sd, err = SddlToSecurityDescriptor(c.SecurityDescriptor)
if err != nil {
return nil, err
}
}
h, err := makeServerPipeHandle(path, sd, c, true)
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
if err != nil {
return nil, err
}
@@ -476,7 +469,7 @@ func connectPipe(p *win32File) error {
err = connectNamedPipe(p.handle, &c.o)
_, err = p.asyncIo(c, nil, 0, err)
if err != nil && err != cERROR_PIPE_CONNECTED {
if err != nil && err != windows.ERROR_PIPE_CONNECTED {
return err
}
return nil

View File

@@ -1,29 +0,0 @@
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package winpipe
import (
"unsafe"
)
//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
//sys localFree(mem uintptr) = LocalFree
//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
var sdBuffer uintptr
err := convertStringSecurityDescriptorToSecurityDescriptor(sddl, 1, &sdBuffer, nil)
if err != nil {
return nil, err
}
defer localFree(sdBuffer)
sd := make([]byte, getSecurityDescriptorLength(sdBuffer))
copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
return sd, nil
}

View File

@@ -39,7 +39,6 @@ func errnoErr(e syscall.Errno) error {
var (
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
modntdll = windows.NewLazySystemDLL("ntdll.dll")
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
@@ -52,9 +51,6 @@ var (
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
procLocalFree = modkernel32.NewProc("LocalFree")
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
@@ -62,7 +58,7 @@ var (
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
)
func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
if e1 != 0 {
@@ -74,7 +70,7 @@ func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
return
}
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name)
if err != nil {
@@ -83,10 +79,10 @@ func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances ui
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
}
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
handle = syscall.Handle(r0)
if handle == syscall.InvalidHandle {
handle = windows.Handle(r0)
if handle == windows.InvalidHandle {
if e1 != 0 {
err = errnoErr(e1)
} else {
@@ -96,7 +92,7 @@ func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances
return
}
func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
func createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name)
if err != nil {
@@ -105,10 +101,10 @@ func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAtt
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
}
func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
func _createFile(name *uint16, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
handle = syscall.Handle(r0)
if handle == syscall.InvalidHandle {
handle = windows.Handle(r0)
if handle == windows.InvalidHandle {
if e1 != 0 {
err = errnoErr(e1)
} else {
@@ -118,7 +114,7 @@ func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityA
return
}
func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
if r1 == 0 {
if e1 != 0 {
@@ -130,7 +126,7 @@ func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSiz
return
}
func getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
if r1 == 0 {
if e1 != 0 {
@@ -148,7 +144,7 @@ func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
return
}
func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
func ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
status = ntstatus(r0)
return
@@ -174,39 +170,7 @@ func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) {
return
}
func convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(str)
if err != nil {
return
}
return _convertStringSecurityDescriptorToSecurityDescriptor(_p0, revision, sd, size)
}
func _convertStringSecurityDescriptorToSecurityDescriptor(str *uint16, revision uint32, sd *uintptr, size *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procConvertStringSecurityDescriptorToSecurityDescriptorW.Addr(), 4, uintptr(unsafe.Pointer(str)), uintptr(revision), uintptr(unsafe.Pointer(sd)), uintptr(unsafe.Pointer(size)), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func localFree(mem uintptr) {
syscall.Syscall(procLocalFree.Addr(), 1, uintptr(mem), 0, 0)
return
}
func getSecurityDescriptorLength(sd uintptr) (len uint32) {
r0, _, _ := syscall.Syscall(procGetSecurityDescriptorLength.Addr(), 1, uintptr(sd), 0, 0)
len = uint32(r0)
return
}
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
if e1 != 0 {
@@ -218,9 +182,9 @@ func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
return
}
func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) {
func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) {
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
newport = syscall.Handle(r0)
newport = windows.Handle(r0)
if newport == 0 {
if e1 != 0 {
err = errnoErr(e1)
@@ -231,7 +195,7 @@ func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintpt
return
}
func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
if r1 == 0 {
if e1 != 0 {
@@ -243,7 +207,7 @@ func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr,
return
}
func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) {
func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
if r1 == 0 {
if e1 != 0 {
@@ -255,7 +219,7 @@ func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err erro
return
}
func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
var _p0 uint32
if wait {
_p0 = 1

16
main.go
View File

@@ -40,31 +40,19 @@ func warning() {
if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
return
}
shouldQuit := os.Getenv("WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD") != "1"
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G")
fmt.Fprintln(os.Stderr, "W which is probably unnecessary and foolish. This G")
fmt.Fprintln(os.Stderr, "W which is probably unnecessary and misguided. This G")
fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G")
fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G")
fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G")
fmt.Fprintln(os.Stderr, "W implementation. For more information on G")
fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G")
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
if shouldQuit {
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W If you still want to use this program, against G")
fmt.Fprintln(os.Stderr, "W the advice here, please first export this G")
fmt.Fprintln(os.Stderr, "W environment variable: G")
fmt.Fprintln(os.Stderr, "W WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD=1 G")
}
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
if shouldQuit {
os.Exit(1)
}
}
func main() {
@@ -75,8 +63,6 @@ func main() {
warning()
// parse arguments
var foreground bool
var interfaceName string
if len(os.Args) < 2 || len(os.Args) > 3 {

View File

@@ -37,7 +37,7 @@ func main() {
logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
logger.Debug.Println("Debug log enabled")
tun, err := tun.CreateTUN(interfaceName)
tun, err := tun.CreateTUN(interfaceName, 0)
if err == nil {
realInterfaceName, err2 := tun.Name()
if err2 == nil {

View File

@@ -36,7 +36,7 @@ func TestRatelimiter(t *testing.T) {
for i := 0; i < packetsBurstable; i++ {
Add(RatelimiterResult{
allowed: true,
text: "inital burst",
text: "initial burst",
})
}

View File

@@ -60,7 +60,13 @@ func (rw *RWCancel) ReadyRead() bool {
fdset := fdSet{}
fdset.set(rw.fd)
fdset.set(closeFd)
err := unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
var err error
for {
err = unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
if err == nil || !RetryAfterError(err) {
break
}
}
if err != nil {
return false
}
@@ -75,7 +81,13 @@ func (rw *RWCancel) ReadyWrite() bool {
fdset := fdSet{}
fdset.set(rw.fd)
fdset.set(closeFd)
err := unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
var err error
for {
err = unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
if err == nil || !RetryAfterError(err) {
break
}
}
if err != nil {
return false
}

View File

@@ -11,6 +11,7 @@ import (
"net"
"os"
"syscall"
"time"
"unsafe"
"golang.org/x/net/ipv6"
@@ -42,6 +43,22 @@ type NativeTun struct {
var sockaddrCtlSize uintptr = 32
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
for i := 0; i < 20; i++ {
iface, err = net.InterfaceByIndex(index)
if err != nil {
if opErr, ok := err.(*net.OpError); ok {
if syscallErr, ok := opErr.Err.(*os.SyscallError); ok && syscallErr.Err == syscall.ENOMEM {
time.Sleep(time.Duration(i) * time.Second / 3)
continue
}
}
}
return iface, err
}
return nil, err
}
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
var (
statusUp bool
@@ -74,7 +91,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
continue
}
iface, err := net.InterfaceByIndex(ifindex)
iface, err := retryInterfaceByIndex(ifindex)
if err != nil {
tun.errors <- err
return

View File

@@ -35,7 +35,7 @@ type NativeTun struct {
name string // name of interface
errors chan error // async error handling
events chan Event // device related events
nopi bool // the device was pased IFF_NO_PI
nopi bool // the device was passed IFF_NO_PI
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
@@ -85,7 +85,7 @@ func createNetlinkSocket() (int, error) {
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))),
Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
}
err = unix.Bind(sock, saddr)
if err != nil {

View File

@@ -42,34 +42,11 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
defer close(tun.events)
data := make([]byte, os.Getpagesize())
for {
retry:
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
goto retry
}
tun.errors <- err
return
}
if n < 8 {
continue
}
if data[3 /* type */] != unix.RTM_IFINFO {
continue
}
ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */])))
if ifindex != tunIfindex {
continue
}
iface, err := net.InterfaceByIndex(ifindex)
check := func() bool {
iface, err := net.InterfaceByIndex(tunIfindex)
if err != nil {
tun.errors <- err
return
return true
}
// Up / Down event
@@ -87,6 +64,38 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
tun.events <- EventMTUUpdate
}
statusMTU = iface.MTU
return false
}
if check() {
return
}
data := make([]byte, os.Getpagesize())
for {
n, err := unix.Read(tun.routeSocket, data)
if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
continue
}
tun.errors <- err
return
}
if n < 8 {
continue
}
if data[3 /* type */] != unix.RTM_IFINFO {
continue
}
ifindex := int(*(*uint16)(unsafe.Pointer(&data[6 /* ifindex */])))
if ifindex != tunIfindex {
continue
}
if check() {
return
}
}
}
@@ -140,7 +149,6 @@ func CreateTUN(name string, mtu int) (Device, error) {
}
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 10),

View File

@@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"os"
"sync"
"sync/atomic"
"time"
"unsafe"
@@ -19,87 +20,71 @@ import (
)
const (
packetAlignment uint32 = 4 // Number of bytes packets are aligned to in rings
packetSizeMax = 0xffff // Maximum packet size
packetCapacity = 0x800000 // Ring capacity, 8MiB
packetTrailingSize = uint32(unsafe.Sizeof(packetHeader{})) + ((packetSizeMax + (packetAlignment - 1)) &^ (packetAlignment - 1)) - packetAlignment
ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14)
rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
spinloopRateThreshold = 800000000 / 8 // 800mbps
spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
)
type packetHeader struct {
size uint32
}
type packet struct {
packetHeader
data [packetSizeMax]byte
}
type ring struct {
head uint32
tail uint32
alertable int32
data [packetCapacity + packetTrailingSize]byte
}
type ringDescriptor struct {
send, receive struct {
size uint32
ring *ring
tailMoved windows.Handle
}
type rateJuggler struct {
current uint64
nextByteCount uint64
nextStartTime int64
changing int32
}
type NativeTun struct {
wt *wintun.Wintun
wt *wintun.Interface
handle windows.Handle
close bool
rings ringDescriptor
events chan Event
errors chan error
forcedMTU int
rate rateJuggler
rings *wintun.RingDescriptor
writeLock sync.Mutex
}
func packetAlign(size uint32) uint32 {
return (size + (packetAlignment - 1)) &^ (packetAlignment - 1)
const WintunPool = wintun.Pool("WireGuard")
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
//go:linkname nanotime runtime.nanotime
func nanotime() int64
//
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
//
func CreateTUN(ifname string, mtu int) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, nil, mtu)
}
//
// CreateTUN creates a Wintun adapter with the given name. Should a Wintun
// adapter with the same name exist, it is reused.
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
//
func CreateTUN(ifname string) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, nil)
}
//
// CreateTUNWithRequestedGUID creates a Wintun adapter with the given name and
// a requested GUID. Should a Wintun adapter with the same name exist, it is reused.
//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Device, error) {
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
var err error
var wt *wintun.Wintun
var wt *wintun.Interface
// Does an interface with this name already exist?
wt, err = wintun.GetInterface(ifname)
wt, err = WintunPool.GetInterface(ifname)
if err == nil {
// If so, we delete it, in case it has weird residual configuration.
_, err = wt.DeleteInterface()
if err != nil {
return nil, fmt.Errorf("Unable to delete already existing Wintun interface: %v", err)
return nil, fmt.Errorf("Error deleting already existing interface: %v", err)
}
} else if err == windows.ERROR_ALREADY_EXISTS {
return nil, fmt.Errorf("Foreign network interface with the same name exists")
}
wt, _, err = wintun.CreateInterface("WireGuard Tunnel Adapter", requestedGUID)
wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID)
if err != nil {
return nil, fmt.Errorf("Unable to create Wintun interface: %v", err)
return nil, fmt.Errorf("Error creating interface: %v", err)
}
err = wt.SetInterfaceName(ifname)
if err != nil {
wt.DeleteInterface()
return nil, fmt.Errorf("Unable to set name of Wintun interface: %v", err)
forcedMTU := 1420
if mtu > 0 {
forcedMTU = mtu
}
tun := &NativeTun{
@@ -107,33 +92,16 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev
handle: windows.InvalidHandle,
events: make(chan Event, 10),
errors: make(chan error, 1),
forcedMTU: 1500,
forcedMTU: forcedMTU,
}
tun.rings.send.size = uint32(unsafe.Sizeof(ring{}))
tun.rings.send.ring = &ring{}
tun.rings.send.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
tun.rings, err = wintun.NewRingDescriptor()
if err != nil {
tun.Close()
return nil, fmt.Errorf("Error creating event: %v", err)
return nil, fmt.Errorf("Error creating events: %v", err)
}
tun.rings.receive.size = uint32(unsafe.Sizeof(ring{}))
tun.rings.receive.ring = &ring{}
tun.rings.receive.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
tun.Close()
return nil, fmt.Errorf("Error creating event: %v", err)
}
tun.handle, err = tun.wt.AdapterHandle()
if err != nil {
tun.Close()
return nil, err
}
var bytesReturned uint32
err = windows.DeviceIoControl(tun.handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(&tun.rings)), uint32(unsafe.Sizeof(tun.rings)), nil, 0, &bytesReturned, nil)
tun.handle, err = tun.wt.Register(tun.rings)
if err != nil {
tun.Close()
return nil, fmt.Errorf("Error registering rings: %v", err)
@@ -142,7 +110,7 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev
}
func (tun *NativeTun) Name() (string, error) {
return tun.wt.InterfaceName()
return tun.wt.Name()
}
func (tun *NativeTun) File() *os.File {
@@ -155,18 +123,13 @@ func (tun *NativeTun) Events() chan Event {
func (tun *NativeTun) Close() error {
tun.close = true
if tun.rings.send.tailMoved != 0 {
windows.SetEvent(tun.rings.send.tailMoved) // wake the reader if it's sleeping
if tun.rings.Send.TailMoved != 0 {
windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping
}
if tun.handle != windows.InvalidHandle {
windows.CloseHandle(tun.handle)
}
if tun.rings.send.tailMoved != 0 {
windows.CloseHandle(tun.rings.send.tailMoved)
}
if tun.rings.send.tailMoved != 0 {
windows.CloseHandle(tun.rings.receive.tailMoved)
}
tun.rings.Close()
var err error
if tun.wt != nil {
_, err = tun.wt.DeleteInterface()
@@ -184,9 +147,6 @@ func (tun *NativeTun) ForceMTU(mtu int) {
tun.forcedMTU = mtu
}
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
@@ -200,50 +160,52 @@ retry:
return 0, os.ErrClosed
}
buffHead := atomic.LoadUint32(&tun.rings.send.ring.head)
if buffHead >= packetCapacity {
buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head)
if buffHead >= wintun.PacketCapacity {
return 0, os.ErrClosed
}
start := time.Now()
start := nanotime()
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
var buffTail uint32
for {
buffTail = atomic.LoadUint32(&tun.rings.send.ring.tail)
buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail)
if buffHead != buffTail {
break
}
if tun.close {
return 0, os.ErrClosed
}
if time.Since(start) >= time.Millisecond/80 /* ~1gbit/s */ {
windows.WaitForSingleObject(tun.rings.send.tailMoved, windows.INFINITE)
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE)
goto retry
}
procyield(1)
}
if buffTail >= packetCapacity {
if buffTail >= wintun.PacketCapacity {
return 0, os.ErrClosed
}
buffContent := tun.rings.send.ring.wrap(buffTail - buffHead)
if buffContent < uint32(unsafe.Sizeof(packetHeader{})) {
buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead)
if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) {
return 0, errors.New("incomplete packet header in send ring")
}
packet := (*packet)(unsafe.Pointer(&tun.rings.send.ring.data[buffHead]))
if packet.size > packetSizeMax {
packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead]))
if packet.Size > wintun.PacketSizeMax {
return 0, errors.New("packet too big in send ring")
}
alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packet.size)
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size)
if alignedPacketSize > buffContent {
return 0, errors.New("incomplete packet in send ring")
}
copy(buff[offset:], packet.data[:packet.size])
buffHead = tun.rings.send.ring.wrap(buffHead + alignedPacketSize)
atomic.StoreUint32(&tun.rings.send.ring.head, buffHead)
return int(packet.size), nil
copy(buff[offset:], packet.Data[:packet.Size])
buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize)
atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead)
tun.rate.update(uint64(packet.Size))
return int(packet.Size), nil
}
func (tun *NativeTun) Flush() error {
@@ -256,39 +218,58 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
}
packetSize := uint32(len(buff) - offset)
alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packetSize)
tun.rate.update(uint64(packetSize))
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
buffHead := atomic.LoadUint32(&tun.rings.receive.ring.head)
if buffHead >= packetCapacity {
tun.writeLock.Lock()
defer tun.writeLock.Unlock()
buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
if buffHead >= wintun.PacketCapacity {
return 0, os.ErrClosed
}
buffTail := atomic.LoadUint32(&tun.rings.receive.ring.tail)
if buffTail >= packetCapacity {
buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail)
if buffTail >= wintun.PacketCapacity {
return 0, os.ErrClosed
}
buffSpace := tun.rings.receive.ring.wrap(buffHead - buffTail - packetAlignment)
buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment)
if alignedPacketSize > buffSpace {
return 0, nil // Dropping when ring is full.
}
packet := (*packet)(unsafe.Pointer(&tun.rings.receive.ring.data[buffTail]))
packet.size = packetSize
copy(packet.data[:packetSize], buff[offset:])
atomic.StoreUint32(&tun.rings.receive.ring.tail, tun.rings.receive.ring.wrap(buffTail+alignedPacketSize))
if atomic.LoadInt32(&tun.rings.receive.ring.alertable) != 0 {
windows.SetEvent(tun.rings.receive.tailMoved)
packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail]))
packet.Size = packetSize
copy(packet.Data[:packetSize], buff[offset:])
atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize))
if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 {
windows.SetEvent(tun.rings.Receive.TailMoved)
}
return int(packetSize), nil
}
// LUID returns Windows adapter instance ID.
// LUID returns Windows interface instance ID.
func (tun *NativeTun) LUID() uint64 {
return tun.wt.LUID()
}
// wrap returns value modulo ring capacity
func (rb *ring) wrap(value uint32) uint32 {
return value & (packetCapacity - 1)
// Version returns the version of the Wintun driver and NDIS system currently loaded.
func (tun *NativeTun) Version() (driverVersion string, ndisVersion string, err error) {
return tun.wt.Version()
}
func (rate *rateJuggler) update(packetLen uint64) {
now := nanotime()
total := atomic.AddUint64(&rate.nextByteCount, packetLen)
period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
if period >= rateMeasurementGranularity {
if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
return
}
atomic.StoreInt64(&rate.nextStartTime, now)
atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
atomic.StoreUint64(&rate.nextByteCount, 0)
atomic.StoreInt32(&rate.changing, 0)
}
}

View File

@@ -0,0 +1,25 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package iphlpapi
import "golang.org/x/sys/windows"
//sys convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) = iphlpapi.ConvertInterfaceLuidToGuid
//sys convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) = iphlpapi.ConvertInterfaceAliasToLuid
func InterfaceGUIDFromAlias(alias string) (*windows.GUID, error) {
var luid uint64
var guid windows.GUID
err := convertInterfaceAliasToLUID(windows.StringToUTF16Ptr(alias), &luid)
if err != nil {
return nil, err
}
err = convertInterfaceLUIDToGUID(&luid, &guid)
if err != nil {
return nil, err
}
return &guid, nil
}

View File

@@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package iphlpapi
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go conversion_windows.go

View File

@@ -0,0 +1,60 @@
// Code generated by 'go generate'; DO NOT EDIT.
package iphlpapi
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return nil
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid")
procConvertInterfaceAliasToLuid = modiphlpapi.NewProc("ConvertInterfaceAliasToLuid")
)
func convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) {
r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) {
r0, _, _ := syscall.Syscall(procConvertInterfaceAliasToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceAlias)), uintptr(unsafe.Pointer(interfaceLUID)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}

View File

@@ -0,0 +1,98 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"encoding/hex"
"errors"
"fmt"
"sync"
"unsafe"
"golang.org/x/crypto/blake2s"
"golang.org/x/sys/windows"
"golang.org/x/text/unicode/norm"
"golang.zx2c4.com/wireguard/tun/wintun/namespaceapi"
)
var (
wintunObjectSecurityAttributes *windows.SecurityAttributes
hasInitializedNamespace bool
initializingNamespace sync.Mutex
)
func initializeNamespace() error {
initializingNamespace.Lock()
defer initializingNamespace.Unlock()
if hasInitializedNamespace {
return nil
}
sd, err := windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
if err != nil {
return fmt.Errorf("SddlToSecurityDescriptor failed: %v", err)
}
wintunObjectSecurityAttributes = &windows.SecurityAttributes{
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
SecurityDescriptor: sd,
}
sid, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
if err != nil {
return fmt.Errorf("CreateWellKnownSid(LOCAL_SYSTEM) failed: %v", err)
}
boundary, err := namespaceapi.CreateBoundaryDescriptor("Wintun")
if err != nil {
return fmt.Errorf("CreateBoundaryDescriptor failed: %v", err)
}
err = boundary.AddSid(sid)
if err != nil {
return fmt.Errorf("AddSIDToBoundaryDescriptor failed: %v", err)
}
for {
_, err = namespaceapi.CreatePrivateNamespace(wintunObjectSecurityAttributes, boundary, "Wintun")
if err == windows.ERROR_ALREADY_EXISTS {
_, err = namespaceapi.OpenPrivateNamespace(boundary, "Wintun")
if err == windows.ERROR_PATH_NOT_FOUND {
continue
}
}
if err != nil {
return fmt.Errorf("Create/OpenPrivateNamespace failed: %v", err)
}
break
}
hasInitializedNamespace = true
return nil
}
func (pool Pool) takeNameMutex() (windows.Handle, error) {
err := initializeNamespace()
if err != nil {
return 0, err
}
const mutexLabel = "WireGuard Adapter Name Mutex Stable Suffix v1 jason@zx2c4.com"
b2, _ := blake2s.New256(nil)
b2.Write([]byte(mutexLabel))
b2.Write(norm.NFC.Bytes([]byte(string(pool))))
mutexName := `Wintun\Wintun-Name-Mutex-` + hex.EncodeToString(b2.Sum(nil))
mutex, err := windows.CreateMutex(wintunObjectSecurityAttributes, false, windows.StringToUTF16Ptr(mutexName))
if err != nil {
err = fmt.Errorf("Error creating name mutex: %v", err)
return 0, err
}
event, err := windows.WaitForSingleObject(mutex, windows.INFINITE)
if err != nil {
windows.CloseHandle(mutex)
return 0, fmt.Errorf("Error waiting on name mutex: %v", err)
}
if event != windows.WAIT_OBJECT_0 && event != windows.WAIT_ABANDONED {
windows.CloseHandle(mutex)
return 0, errors.New("Error with event trigger of name mutex")
}
return mutex, nil
}

View File

@@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package namespaceapi
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go namespaceapi_windows.go

View File

@@ -0,0 +1,83 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package namespaceapi
import "golang.org/x/sys/windows"
//sys createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) = kernel32.CreateBoundaryDescriptorW
//sys deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) = kernel32.DeleteBoundaryDescriptor
//sys addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) = kernel32.AddSIDToBoundaryDescriptor
//sys createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.CreatePrivateNamespaceW
//sys openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.OpenPrivateNamespaceW
//sys closePrivateNamespace(handle windows.Handle, flags uint32) (err error) = kernel32.ClosePrivateNamespace
// BoundaryDescriptor represents a boundary that defines how the objects in the namespace are to be isolated.
type BoundaryDescriptor windows.Handle
// CreateBoundaryDescriptor creates a boundary descriptor.
func CreateBoundaryDescriptor(name string) (BoundaryDescriptor, error) {
name16, err := windows.UTF16PtrFromString(name)
if err != nil {
return 0, err
}
handle, err := createBoundaryDescriptor(name16, 0)
if err != nil {
return 0, err
}
return BoundaryDescriptor(handle), nil
}
// Delete deletes the specified boundary descriptor.
func (bd BoundaryDescriptor) Delete() {
deleteBoundaryDescriptor(windows.Handle(bd))
}
// AddSid adds a security identifier (SID) to the specified boundary descriptor.
func (bd *BoundaryDescriptor) AddSid(requiredSid *windows.SID) error {
return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid)
}
// PrivateNamespace represents a private namespace.
type PrivateNamespace windows.Handle
// CreatePrivateNamespace creates a private namespace.
func CreatePrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) {
aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix)
if err != nil {
return 0, err
}
handle, err := createPrivateNamespace(privateNamespaceAttributes, windows.Handle(boundaryDescriptor), aliasPrefix16)
if err != nil {
return 0, err
}
return PrivateNamespace(handle), nil
}
// OpenPrivateNamespace opens a private namespace.
func OpenPrivateNamespace(boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) {
aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix)
if err != nil {
return 0, err
}
handle, err := openPrivateNamespace(windows.Handle(boundaryDescriptor), aliasPrefix16)
if err != nil {
return 0, err
}
return PrivateNamespace(handle), nil
}
// ClosePrivateNamespaceFlags describes flags that are used by PrivateNamespace's Close() method.
type ClosePrivateNamespaceFlags uint32
const (
// PrivateNamespaceFlagDestroy makes the close to destroy the namespace.
PrivateNamespaceFlagDestroy = ClosePrivateNamespaceFlags(0x1)
)
// Close closes an open namespace handle.
func (pns PrivateNamespace) Close(flags ClosePrivateNamespaceFlags) error {
return closePrivateNamespace(windows.Handle(pns), uint32(flags))
}

View File

@@ -0,0 +1,116 @@
// Code generated by 'go generate'; DO NOT EDIT.
package namespaceapi
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return nil
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
procCreateBoundaryDescriptorW = modkernel32.NewProc("CreateBoundaryDescriptorW")
procDeleteBoundaryDescriptor = modkernel32.NewProc("DeleteBoundaryDescriptor")
procAddSIDToBoundaryDescriptor = modkernel32.NewProc("AddSIDToBoundaryDescriptor")
procCreatePrivateNamespaceW = modkernel32.NewProc("CreatePrivateNamespaceW")
procOpenPrivateNamespaceW = modkernel32.NewProc("OpenPrivateNamespaceW")
procClosePrivateNamespace = modkernel32.NewProc("ClosePrivateNamespace")
)
func createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall(procCreateBoundaryDescriptorW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(flags), 0)
handle = windows.Handle(r0)
if handle == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) {
syscall.Syscall(procDeleteBoundaryDescriptor.Addr(), 1, uintptr(boundaryDescriptor), 0, 0)
return
}
func addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) {
r1, _, e1 := syscall.Syscall(procAddSIDToBoundaryDescriptor.Addr(), 2, uintptr(unsafe.Pointer(boundaryDescriptor)), uintptr(unsafe.Pointer(requiredSid)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall(procCreatePrivateNamespaceW.Addr(), 3, uintptr(unsafe.Pointer(privateNamespaceAttributes)), uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)))
handle = windows.Handle(r0)
if handle == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall(procOpenPrivateNamespaceW.Addr(), 2, uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)), 0)
handle = windows.Handle(r0)
if handle == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func closePrivateNamespace(handle windows.Handle, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procClosePrivateNamespace.Addr(), 2, uintptr(handle), uintptr(flags), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}

View File

@@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package nci
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go nci_windows.go

View File

@@ -0,0 +1,28 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package nci
import "golang.org/x/sys/windows"
//sys nciSetConnectionName(guid *windows.GUID, newName *uint16) (ret error) = nci.NciSetConnectionName
//sys nciGetConnectionName(guid *windows.GUID, destName *uint16, inDestNameBytes uint32, outDestNameBytes *uint32) (ret error) = nci.NciGetConnectionName
func SetConnectionName(guid *windows.GUID, newName string) error {
newName16, err := windows.UTF16PtrFromString(newName)
if err != nil {
return err
}
return nciSetConnectionName(guid, newName16)
}
func ConnectionName(guid *windows.GUID) (string, error) {
var name [0x400]uint16
err := nciGetConnectionName(guid, &name[0], uint32(len(name)*2), nil)
if err != nil {
return "", err
}
return windows.UTF16ToString(name[:]), nil
}

View File

@@ -0,0 +1,60 @@
// Code generated by 'go generate'; DO NOT EDIT.
package nci
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return nil
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modnci = windows.NewLazySystemDLL("nci.dll")
procNciSetConnectionName = modnci.NewProc("NciSetConnectionName")
procNciGetConnectionName = modnci.NewProc("NciGetConnectionName")
)
func nciSetConnectionName(guid *windows.GUID, newName *uint16) (ret error) {
r0, _, _ := syscall.Syscall(procNciSetConnectionName.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(newName)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func nciGetConnectionName(guid *windows.GUID, destName *uint16, inDestNameBytes uint32, outDestNameBytes *uint32) (ret error) {
r0, _, _ := syscall.Syscall6(procNciGetConnectionName.Addr(), 4, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(destName)), uintptr(inDestNameBytes), uintptr(unsafe.Pointer(outDestNameBytes)), 0, 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}

View File

@@ -1,32 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package netshell
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var (
modnetshell = windows.NewLazySystemDLL("netshell.dll")
procHrRenameConnection = modnetshell.NewProc("HrRenameConnection")
)
func HrRenameConnection(guid *windows.GUID, newName *uint16) (err error) {
err = procHrRenameConnection.Find()
if err != nil {
// Missing from servercore, so we can't presume it's always there.
return
}
ret, _, _ := syscall.Syscall(procHrRenameConnection.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(newName)), 0)
if ret != 0 {
err = syscall.Errno(ret)
}
return
}

View File

@@ -5,4 +5,4 @@
package registry
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zregistry_windows.go registry_windows.go
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zregistry_windows.go registry_windows.go

117
tun/wintun/ring_windows.go Normal file
View File

@@ -0,0 +1,117 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"runtime"
"unsafe"
"golang.org/x/sys/windows"
)
const (
PacketAlignment = 4 // Number of bytes packets are aligned to in rings
PacketSizeMax = 0xffff // Maximum packet size
PacketCapacity = 0x800000 // Ring capacity, 8MiB
PacketTrailingSize = uint32(unsafe.Sizeof(PacketHeader{})) + ((PacketSizeMax + (PacketAlignment - 1)) &^ (PacketAlignment - 1)) - PacketAlignment
ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14)
)
type PacketHeader struct {
Size uint32
}
type Packet struct {
PacketHeader
Data [PacketSizeMax]byte
}
type Ring struct {
Head uint32
Tail uint32
Alertable int32
Data [PacketCapacity + PacketTrailingSize]byte
}
type RingDescriptor struct {
Send, Receive struct {
Size uint32
Ring *Ring
TailMoved windows.Handle
}
}
// Wrap returns value modulo ring capacity
func (rb *Ring) Wrap(value uint32) uint32 {
return value & (PacketCapacity - 1)
}
// Aligns a packet size to PacketAlignment
func PacketAlign(size uint32) uint32 {
return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
}
func NewRingDescriptor() (descriptor *RingDescriptor, err error) {
descriptor = new(RingDescriptor)
allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
if err != nil {
return
}
defer func() {
if err != nil {
descriptor.free()
descriptor = nil
}
}()
descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion))
descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return
}
descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{})))
descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
windows.CloseHandle(descriptor.Send.TailMoved)
return
}
runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() })
return
}
func (descriptor *RingDescriptor) free() {
if descriptor.Send.Ring != nil {
windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE)
descriptor.Send.Ring = nil
descriptor.Receive.Ring = nil
}
}
func (descriptor *RingDescriptor) Close() {
if descriptor.Send.TailMoved != 0 {
windows.CloseHandle(descriptor.Send.TailMoved)
descriptor.Send.TailMoved = 0
}
if descriptor.Send.TailMoved != 0 {
windows.CloseHandle(descriptor.Receive.TailMoved)
descriptor.Receive.TailMoved = 0
}
}
func (wintun *Interface) Register(descriptor *RingDescriptor) (windows.Handle, error) {
handle, err := wintun.handle()
if err != nil {
return 0, err
}
var bytesReturned uint32
err = windows.DeviceIoControl(handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(descriptor)), uint32(unsafe.Sizeof(*descriptor)), nil, 0, &bytesReturned, nil)
if err != nil {
return 0, err
}
return handle, nil
}

View File

@@ -5,4 +5,4 @@
package setupapi
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsetupapi_windows.go setupapi_windows.go
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsetupapi_windows.go setupapi_windows.go

View File

@@ -57,7 +57,7 @@ type DevInfoData struct {
_ uintptr
}
// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass).
// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supersedes the functionality of SetupDiGetDeviceInfoListClass).
type DevInfoListDetailData struct {
size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const.
ClassGUID windows.GUID

View File

@@ -15,17 +15,20 @@ import (
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
"golang.zx2c4.com/wireguard/tun/wintun/netshell"
"golang.zx2c4.com/wireguard/tun/wintun/iphlpapi"
"golang.zx2c4.com/wireguard/tun/wintun/nci"
registryEx "golang.zx2c4.com/wireguard/tun/wintun/registry"
"golang.zx2c4.com/wireguard/tun/wintun/setupapi"
)
// Wintun is a handle of a Wintun adapter.
type Wintun struct {
type Pool string
type Interface struct {
cfgInstanceID windows.GUID
devInstanceID string
luidIndex uint32
ifType uint32
pool Pool
}
var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}}
@@ -37,9 +40,9 @@ const (
)
// makeWintun creates a Wintun interface handle and populates it from the device's registry key.
func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) (*Wintun, error) {
func makeWintun(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData, pool Pool) (*Interface, error) {
// Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE)
key, err := devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE)
if err != nil {
return nil, fmt.Errorf("Device-specific registry key open failed: %v", err)
}
@@ -69,30 +72,48 @@ func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfo
return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err)
}
instanceID, err := deviceInfoSet.DeviceInstanceID(deviceInfoData)
instanceID, err := devInfo.DeviceInstanceID(devInfoData)
if err != nil {
return nil, fmt.Errorf("DeviceInstanceID failed: %v", err)
}
return &Wintun{
return &Interface{
cfgInstanceID: ifid,
devInstanceID: instanceID,
luidIndex: uint32(luidIdx),
ifType: uint32(ifType),
pool: pool,
}, nil
}
func removeNumberedSuffix(ifname string) string {
removed := strings.TrimRight(ifname, "0123456789")
if removed != ifname && len(removed) > 1 && removed[len(removed)-1] == ' ' {
return removed[:len(removed)-1]
}
return ifname
}
// GetInterface finds a Wintun interface by its name. This function returns
// the interface if found, or windows.ERROR_OBJECT_NOT_FOUND otherwise. If
// the interface is found but not a Wintun-class, this function returns
// windows.ERROR_ALREADY_EXISTS.
func GetInterface(ifname string) (*Wintun, error) {
// the interface is found but not a Wintun-class or a member of the pool,
// this function returns windows.ERROR_ALREADY_EXISTS.
func (pool Pool) GetInterface(ifname string) (*Interface, error) {
mutex, err := pool.takeNameMutex()
if err != nil {
return nil, err
}
defer func() {
windows.ReleaseMutex(mutex)
windows.CloseHandle(mutex)
}()
// Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil {
return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err)
}
defer devInfoList.Close()
defer devInfo.Close()
// Windows requires each interface to have a different name. When
// enforcing this, Windows treats interface names case-insensitive. If an
@@ -102,7 +123,7 @@ func GetInterface(ifname string) (*Wintun, error) {
ifname = strings.ToLower(ifname)
for index := 0; ; index++ {
deviceData, err := devInfoList.EnumDeviceInfo(index)
devInfoData, err := devInfo.EnumDeviceInfo(index)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
@@ -110,26 +131,37 @@ func GetInterface(ifname string) (*Wintun, error) {
continue
}
wintun, err := makeWintun(devInfoList, deviceData)
// Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
if err != nil {
continue
}
if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
continue
}
wintun, err := makeWintun(devInfo, devInfoData, pool)
if err != nil {
continue
}
// TODO: is there a better way than comparing ifnames?
ifname2, err := wintun.InterfaceName()
ifname2, err := wintun.Name()
if err != nil {
continue
}
ifname2 = strings.ToLower(ifname2)
ifname3 := removeNumberedSuffix(ifname2)
if ifname == strings.ToLower(ifname2) {
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
if ifname == ifname2 || ifname == ifname3 {
err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
if err != nil {
return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
}
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
for index := 0; ; index++ {
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, index)
driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
@@ -138,12 +170,20 @@ func GetInterface(ifname string) (*Wintun, error) {
}
// Get driver info details.
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil {
continue
}
if driverDetailData.IsCompatible(hardwareID) {
isMember, err := pool.isMember(devInfo, devInfoData)
if err != nil {
return nil, err
}
if !isMember {
return nil, windows.ERROR_ALREADY_EXISTS
}
return wintun, nil
}
}
@@ -156,24 +196,31 @@ func GetInterface(ifname string) (*Wintun, error) {
return nil, windows.ERROR_OBJECT_NOT_FOUND
}
// CreateInterface creates a Wintun interface. description is a string that
// supplies the text description of the device. The description is optional
// and can be "". requestedGUID is the GUID of the created network interface,
// which then influences NLA generation deterministically. If it is set to nil,
// the GUID is chosen by the system at random, and hence a new NLA entry is
// created for each new interface. It is called "requested" GUID because the
// API it uses is completely undocumented, and so there could be minor
// CreateInterface creates a Wintun interface. ifname is the requested name of
// the interface, while requestedGUID is the GUID of the created network
// interface, which then influences NLA generation deterministically. If it is
// set to nil, the GUID is chosen by the system at random, and hence a new NLA
// entry is created for each new interface. It is called "requested" GUID
// because the API it uses is completely undocumented, and so there could be minor
// interesting complications with its usage. This function returns the network
// interface ID and a flag if reboot is required.
//
func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *Wintun, rebootRequired bool, err error) {
func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wintun *Interface, rebootRequired bool, err error) {
mutex, err := pool.takeNameMutex()
if err != nil {
return
}
defer func() {
windows.ReleaseMutex(mutex)
windows.CloseHandle(mutex)
}()
// Create an empty device info set for network adapter device class.
devInfoList, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
devInfo, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
if err != nil {
err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err)
return
}
defer devInfoList.Close()
defer devInfo.Close()
// Get the device class name from GUID.
className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "")
@@ -183,43 +230,44 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
}
// Create a new device info element and add it to the device info set.
deviceData, err := devInfoList.CreateDeviceInfo(className, &deviceClassNetGUID, description, 0, setupapi.DICD_GENERATE_ID)
deviceTypeName := pool.deviceTypeName()
devInfoData, err := devInfo.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID)
if err != nil {
err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err)
return
}
err = setQuietInstall(devInfoList, deviceData)
err = setQuietInstall(devInfo, devInfoData)
if err != nil {
err = fmt.Errorf("Setting quiet installation failed: %v", err)
return
}
// Set a device information element as the selected member of a device information set.
err = devInfoList.SetSelectedDevice(deviceData)
err = devInfo.SetSelectedDevice(devInfoData)
if err != nil {
err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err)
return
}
// Set Plug&Play device hardware ID property.
err = devInfoList.SetDeviceRegistryPropertyString(deviceData, setupapi.SPDRP_HARDWAREID, hardwareID)
err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_HARDWAREID, hardwareID)
if err != nil {
err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err)
return
}
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms
err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms
if err != nil {
err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
return
}
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
driverDate := windows.Filetime{}
driverVersion := uint64(0)
for index := 0; ; index++ { // TODO: This loop takes ~600ms
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, index)
driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
@@ -229,13 +277,13 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
// Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match.
if driverData.IsNewer(driverDate, driverVersion) {
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil {
continue
}
if driverDetailData.IsCompatible(hardwareID) {
err := devInfoList.SetSelectedDriver(deviceData, driverData)
err := devInfo.SetSelectedDriver(devInfoData, driverData)
if err != nil {
continue
}
@@ -251,42 +299,6 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
return
}
// Call appropriate class installer.
err = devInfoList.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, deviceData)
if err != nil {
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err)
return
}
// Register device co-installers if any. (Ignore errors)
devInfoList.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, deviceData)
var key registry.Key
const pollTimeout = time.Millisecond * 50
for i := 0; i < int(waitForRegistryTimeout/pollTimeout); i++ {
if i != 0 {
time.Sleep(pollTimeout)
}
key, err = devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY)
if err == nil {
break
}
}
if err != nil {
err = fmt.Errorf("SetupDiOpenDevRegKey failed: %v", err)
return
}
defer key.Close()
if requestedGUID != nil {
err = key.SetStringValue("NetSetupAnticipatedInstanceId", requestedGUID.String())
if err != nil {
err = fmt.Errorf("SetStringValue(NetSetupAnticipatedInstanceId) failed: %v", err)
return
}
}
// Install interfaces if any. (Ignore errors)
devInfoList.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, deviceData)
defer func() {
if err != nil {
// The interface failed to install, or the interface ID was unobtainable. Clean-up.
@@ -296,10 +308,10 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
}
// Set class installer parameters for DIF_REMOVE.
if devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil {
if devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil {
// Call appropriate class installer.
if devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) == nil {
rebootRequired = rebootRequired || checkReboot(devInfoList, deviceData)
if devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData) == nil {
rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
}
}
@@ -307,64 +319,85 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
}
}()
// Call appropriate class installer.
err = devInfo.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, devInfoData)
if err != nil {
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err)
return
}
// Register device co-installers if any. (Ignore errors)
devInfo.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, devInfoData)
var netDevRegKey registry.Key
const pollTimeout = time.Millisecond * 50
for i := 0; i < int(waitForRegistryTimeout/pollTimeout); i++ {
if i != 0 {
time.Sleep(pollTimeout)
}
netDevRegKey, err = devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY)
if err == nil {
break
}
}
if err != nil {
err = fmt.Errorf("SetupDiOpenDevRegKey failed: %v", err)
return
}
defer netDevRegKey.Close()
if requestedGUID != nil {
err = netDevRegKey.SetStringValue("NetSetupAnticipatedInstanceId", requestedGUID.String())
if err != nil {
err = fmt.Errorf("SetStringValue(NetSetupAnticipatedInstanceId) failed: %v", err)
return
}
}
// Install interfaces if any. (Ignore errors)
devInfo.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, devInfoData)
// Install the device.
err = devInfoList.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, deviceData)
err = devInfo.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, devInfoData)
if err != nil {
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err)
return
}
rebootRequired = checkReboot(devInfoList, deviceData)
rebootRequired = checkReboot(devInfo, devInfoData)
err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_DEVICEDESC, deviceTypeName)
if err != nil {
err = fmt.Errorf("SetDeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
return
}
// DIF_INSTALLDEVICE returns almost immediately, while the device installation
// continues in the background. It might take a while, before all registry
// keys and values are populated.
_, err = registryEx.GetStringValueWait(key, "NetCfgInstanceId", waitForRegistryTimeout)
_, err = registryEx.GetStringValueWait(netDevRegKey, "NetCfgInstanceId", waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("GetStringValueWait(NetCfgInstanceId) failed: %v", err)
return
}
_, err = registryEx.GetIntegerValueWait(key, "NetLuidIndex", waitForRegistryTimeout)
_, err = registryEx.GetIntegerValueWait(netDevRegKey, "NetLuidIndex", waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("GetIntegerValueWait(NetLuidIndex) failed: %v", err)
return
}
_, err = registryEx.GetIntegerValueWait(key, "*IfType", waitForRegistryTimeout)
_, err = registryEx.GetIntegerValueWait(netDevRegKey, "*IfType", waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("GetIntegerValueWait(*IfType) failed: %v", err)
return
}
// Get network interface.
wintun, err = makeWintun(devInfoList, deviceData)
wintun, err = makeWintun(devInfo, devInfoData, pool)
if err != nil {
err = fmt.Errorf("makeWintun failed: %v", err)
return
}
// Wait for network registry key to emerge and populate.
key, err = registryEx.OpenKeyWait(
registry.LOCAL_MACHINE,
wintun.netRegKeyName(),
registry.QUERY_VALUE|registry.NOTIFY,
waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("makeWintun failed: %v", err)
return
}
defer key.Close()
_, err = registryEx.GetStringValueWait(key, "Name", waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("GetStringValueWait(Name) failed: %v", err)
return
}
_, err = registryEx.GetStringValueWait(key, "PnPInstanceId", waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("GetStringValueWait(PnPInstanceId) failed: %v", err)
return
}
// Wait for TCP/IP adapter registry key to emerge and populate.
key, err = registryEx.OpenKeyWait(
tcpipAdapterRegKey, err := registryEx.OpenKeyWait(
registry.LOCAL_MACHINE,
wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE|registry.NOTIFY,
waitForRegistryTimeout)
@@ -372,8 +405,8 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", wintun.tcpipAdapterRegKeyName(), err)
return
}
defer key.Close()
_, err = registryEx.GetStringValueWait(key, "IpConfig", waitForRegistryTimeout)
defer tcpipAdapterRegKey.Close()
_, err = registryEx.GetStringValueWait(tcpipAdapterRegKey, "IpConfig", waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("GetStringValueWait(IpConfig) failed: %v", err)
return
@@ -386,28 +419,23 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
}
// Wait for TCP/IP interface registry key to emerge.
key, err = registryEx.OpenKeyWait(
tcpipInterfaceRegKey, err := registryEx.OpenKeyWait(
registry.LOCAL_MACHINE,
tcpipInterfaceRegKeyName, registry.QUERY_VALUE,
tcpipInterfaceRegKeyName, registry.QUERY_VALUE|registry.SET_VALUE,
waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", tcpipInterfaceRegKeyName, err)
return
}
key.Close()
//
// All the registry keys and values we're relying on are present now.
//
defer tcpipInterfaceRegKey.Close()
// Disable dead gateway detection on our interface.
key, err = registry.OpenKey(registry.LOCAL_MACHINE, tcpipInterfaceRegKeyName, registry.SET_VALUE)
tcpipInterfaceRegKey.SetDWordValue("EnableDeadGWDetect", 0)
err = wintun.SetName(ifname)
if err != nil {
err = fmt.Errorf("Error opening interface-specific TCP/IP network registry key: %v", err)
err = fmt.Errorf("Unable to set name of Wintun interface: %v", err)
return
}
key.SetDWordValue("EnableDeadGWDetect", 0)
key.Close()
return
}
@@ -415,15 +443,15 @@ func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *W
// DeleteInterface deletes a Wintun interface. This function succeeds
// if the interface was not found. It returns a bool indicating whether
// a reboot is required.
func (wintun *Wintun) DeleteInterface() (rebootRequired bool, err error) {
devInfoList, deviceData, err := wintun.deviceData()
func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) {
devInfo, devInfoData, err := wintun.devInfoData()
if err == windows.ERROR_OBJECT_NOT_FOUND {
return false, nil
}
if err != nil {
return false, err
}
defer devInfoList.Close()
defer devInfo.Close()
// Remove the device.
removeDeviceParams := setupapi.RemoveDeviceParams{
@@ -432,32 +460,42 @@ func (wintun *Wintun) DeleteInterface() (rebootRequired bool, err error) {
}
// Set class installer parameters for DIF_REMOVE.
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
if err != nil {
return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err)
}
// Call appropriate class installer.
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData)
err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
if err != nil {
return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err)
}
return checkReboot(devInfoList, deviceData), nil
return checkReboot(devInfo, devInfoData), nil
}
// DeleteAllInterfaces deletes all Wintun interfaces, and returns which
// ones it deleted, whether a reboot is required after, and which errors
// occurred during the process.
func DeleteAllInterfaces() (deviceInstancesDeleted []uint32, rebootRequired bool, errors []error) {
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
// DeleteMatchingInterfaces deletes all Wintun interfaces, which match
// given criteria, and returns which ones it deleted, whether a reboot
// is required after, and which errors occurred during the process.
func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool) (deviceInstancesDeleted []uint32, rebootRequired bool, errors []error) {
mutex, err := pool.takeNameMutex()
if err != nil {
errors = append(errors, err)
return
}
defer func() {
windows.ReleaseMutex(mutex)
windows.CloseHandle(mutex)
}()
devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil {
return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())}
}
defer devInfoList.Close()
defer devInfo.Close()
for i := 0; ; i++ {
deviceData, err := devInfoList.EnumDeviceInfo(i)
devInfoData, err := devInfo.EnumDeviceInfo(i)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
@@ -465,22 +503,31 @@ func DeleteAllInterfaces() (deviceInstancesDeleted []uint32, rebootRequired bool
continue
}
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
// Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
if err != nil {
continue
}
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
continue
}
err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
if err != nil {
continue
}
defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
isWintun := false
for j := 0; ; j++ {
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, j)
driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, j)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
}
continue
}
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil {
continue
}
@@ -493,36 +540,71 @@ func DeleteAllInterfaces() (deviceInstancesDeleted []uint32, rebootRequired bool
continue
}
err = setQuietInstall(devInfoList, deviceData)
isMember, err := pool.isMember(devInfo, devInfoData)
if err != nil {
errors = append(errors, err)
continue
}
if !isMember {
continue
}
wintun, err := makeWintun(devInfo, devInfoData, pool)
if err != nil {
errors = append(errors, fmt.Errorf("Unable to make Wintun interface object: %v", err))
continue
}
if !matches(wintun) {
continue
}
err = setQuietInstall(devInfo, devInfoData)
if err != nil {
errors = append(errors, err)
continue
}
inst := deviceData.DevInst
inst := devInfoData.DevInst
removeDeviceParams := setupapi.RemoveDeviceParams{
ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
}
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
if err != nil {
errors = append(errors, err)
continue
}
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData)
err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
if err != nil {
errors = append(errors, err)
continue
}
rebootRequired = rebootRequired || checkReboot(devInfoList, deviceData)
rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
deviceInstancesDeleted = append(deviceInstancesDeleted, inst)
}
return
}
// isMember checks if SPDRP_DEVICEDESC or SPDRP_FRIENDLYNAME match device type name.
func (pool Pool) isMember(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) (bool, error) {
deviceDescVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_DEVICEDESC)
if err != nil {
return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
}
deviceDesc, _ := deviceDescVal.(string)
friendlyNameVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_FRIENDLYNAME)
if err != nil {
return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err)
}
friendlyName, _ := friendlyNameVal.(string)
deviceTypeName := pool.deviceTypeName()
return friendlyName == deviceTypeName || deviceDesc == deviceTypeName ||
removeNumberedSuffix(friendlyName) == deviceTypeName || removeNumberedSuffix(deviceDesc) == deviceTypeName, nil
}
// checkReboot checks device install parameters if a system reboot is required.
func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) bool {
devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData)
func checkReboot(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) bool {
devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
if err != nil {
return false
}
@@ -531,57 +613,117 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
}
// setQuietInstall sets device install parameters for a quiet installation
func setQuietInstall(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) error {
devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData)
func setQuietInstall(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) error {
devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
if err != nil {
return err
}
devInstallParams.Flags |= setupapi.DI_QUIETINSTALL
return deviceInfoSet.SetDeviceInstallParams(deviceInfoData, devInstallParams)
return devInfo.SetDeviceInstallParams(devInfoData, devInstallParams)
}
// InterfaceName returns the name of the Wintun interface.
func (wintun *Wintun) InterfaceName() (string, error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.netRegKeyName(), registry.QUERY_VALUE)
// deviceTypeName returns pool-specific device type name.
func (pool Pool) deviceTypeName() string {
return fmt.Sprintf("%s Tunnel", pool)
}
// Name returns the name of the Wintun interface.
func (wintun *Interface) Name() (string, error) {
return nci.ConnectionName(&wintun.cfgInstanceID)
}
// SetName sets name of the Wintun interface.
func (wintun *Interface) SetName(ifname string) error {
const maxSuffix = 1000
availableIfname := ifname
for i := 0; ; i++ {
err := nci.SetConnectionName(&wintun.cfgInstanceID, availableIfname)
if err == windows.ERROR_DUP_NAME {
duplicateGuid, err2 := iphlpapi.InterfaceGUIDFromAlias(availableIfname)
if err2 == nil {
for j := 0; j < maxSuffix; j++ {
proposal := fmt.Sprintf("%s %d", ifname, j+1)
if proposal == availableIfname {
continue
}
err2 = nci.SetConnectionName(duplicateGuid, proposal)
if err2 == windows.ERROR_DUP_NAME {
continue
}
if err2 == nil {
err = nci.SetConnectionName(&wintun.cfgInstanceID, availableIfname)
if err == nil {
break
}
}
break
}
}
}
if err == nil {
break
}
if i > maxSuffix || err != windows.ERROR_DUP_NAME {
return fmt.Errorf("NciSetConnectionName failed: %v", err)
}
availableIfname = fmt.Sprintf("%s %d", ifname, i+1)
}
// TODO: This should use NetSetup2 so that it doesn't get unset.
deviceRegKey, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.deviceRegKeyName(), registry.SET_VALUE)
if err != nil {
return "", fmt.Errorf("Network-specific registry key open failed: %v", err)
return fmt.Errorf("Device-level registry key open failed: %v", err)
}
defer key.Close()
// Get the interface name.
return registryEx.GetStringValue(key, "Name")
}
// SetInterfaceName sets name of the Wintun interface.
func (wintun *Wintun) SetInterfaceName(ifname string) error {
// We have to tell the various runtime COM services about the new name too. We ignore the
// error because netshell isn't available on servercore.
// TODO: netsh.exe falls back to NciSetConnection in this case. If somebody complains, maybe
// we should do the same.
netshell.HrRenameConnection(&wintun.cfgInstanceID, windows.StringToUTF16Ptr(ifname))
// Set the interface name. The above line should have done this too, but in case it failed, we force it.
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.netRegKeyName(), registry.SET_VALUE)
defer deviceRegKey.Close()
err = deviceRegKey.SetStringValue("FriendlyName", wintun.pool.deviceTypeName())
if err != nil {
return fmt.Errorf("Network-specific registry key open failed: %v", err)
return fmt.Errorf("SetStringValue(FriendlyName) failed: %v", err)
}
defer key.Close()
return key.SetStringValue("Name", ifname)
}
// netRegKeyName returns the interface-specific network registry key name.
func (wintun *Wintun) netRegKeyName() string {
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Control\\Network\\%v\\%v\\Connection", deviceClassNetGUID, wintun.cfgInstanceID)
return nil
}
// tcpipAdapterRegKeyName returns the adapter-specific TCP/IP network registry key name.
func (wintun *Wintun) tcpipAdapterRegKeyName() string {
func (wintun *Interface) tcpipAdapterRegKeyName() string {
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", wintun.cfgInstanceID)
}
// deviceRegKeyName returns the device-level registry key name.
func (wintun *Interface) deviceRegKeyName() string {
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Enum\\%v", wintun.devInstanceID)
}
// Version returns the version of the Wintun driver and NDIS system currently loaded.
func (wintun *Interface) Version() (driverVersion string, ndisVersion string, err error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, "SYSTEM\\CurrentControlSet\\Services\\Wintun", registry.QUERY_VALUE)
if err != nil {
return
}
defer key.Close()
driverMajor, _, err := key.GetIntegerValue("DriverMajorVersion")
if err != nil {
return
}
driverMinor, _, err := key.GetIntegerValue("DriverMinorVersion")
if err != nil {
return
}
ndisMajor, _, err := key.GetIntegerValue("NdisMajorVersion")
if err != nil {
return
}
ndisMinor, _, err := key.GetIntegerValue("NdisMinorVersion")
if err != nil {
return
}
driverVersion = fmt.Sprintf("%d.%d", driverMajor, driverMinor)
ndisVersion = fmt.Sprintf("%d.%d", ndisMajor, ndisMinor)
return
}
// tcpipInterfaceRegKeyName returns the interface-specific TCP/IP network registry key name.
func (wintun *Wintun) tcpipInterfaceRegKeyName() (path string, err error) {
func (wintun *Interface) tcpipInterfaceRegKeyName() (path string, err error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE)
if err != nil {
return "", fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %v", err)
@@ -597,18 +739,18 @@ func (wintun *Wintun) tcpipInterfaceRegKeyName() (path string, err error) {
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil
}
// deviceData returns TUN device info list handle and interface device info
// devInfoData returns TUN device info list handle and interface device info
// data. The device info list handle must be closed after use. In case the
// device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned.
func (wintun *Wintun) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData, error) {
func (wintun *Interface) devInfoData() (setupapi.DevInfo, *setupapi.DevInfoData, error) {
// Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil {
return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())
}
for index := 0; ; index++ {
deviceData, err := devInfoList.EnumDeviceInfo(index)
devInfoData, err := devInfo.EnumDeviceInfo(index)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
@@ -618,44 +760,44 @@ func (wintun *Wintun) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData, err
// Get interface ID.
// TODO: Store some ID in the Wintun object such that this call isn't required.
wintun2, err := makeWintun(devInfoList, deviceData)
wintun2, err := makeWintun(devInfo, devInfoData, wintun.pool)
if err != nil {
continue
}
if wintun.cfgInstanceID == wintun2.cfgInstanceID {
err = setQuietInstall(devInfoList, deviceData)
err = setQuietInstall(devInfo, devInfoData)
if err != nil {
devInfoList.Close()
devInfo.Close()
return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err)
}
return devInfoList, deviceData, nil
return devInfo, devInfoData, nil
}
}
devInfoList.Close()
devInfo.Close()
return 0, nil, windows.ERROR_OBJECT_NOT_FOUND
}
// AdapterHandle returns a handle to the adapter device object.
func (wintun *Wintun) AdapterHandle() (windows.Handle, error) {
// handle returns a handle to the interface device object.
func (wintun *Interface) handle() (windows.Handle, error) {
interfaces, err := setupapi.CM_Get_Device_Interface_List(wintun.devInstanceID, &deviceInterfaceNetGUID, setupapi.CM_GET_DEVICE_INTERFACE_LIST_PRESENT)
if err != nil {
return windows.InvalidHandle, err
return windows.InvalidHandle, fmt.Errorf("Error listing NDIS interfaces: %v", err)
}
handle, err := windows.CreateFile(windows.StringToUTF16Ptr(interfaces[0]), windows.GENERIC_READ|windows.GENERIC_WRITE, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, 0, 0)
if err != nil {
return windows.InvalidHandle, fmt.Errorf("Open NDIS device failed: %v", err)
return windows.InvalidHandle, fmt.Errorf("Error opening NDIS device: %v", err)
}
return handle, nil
}
// GUID returns the GUID of the interface.
func (wintun *Wintun) GUID() windows.GUID {
func (wintun *Interface) GUID() windows.GUID {
return wintun.cfgInstanceID
}
// LUID returns the LUID of the interface.
func (wintun *Wintun) LUID() uint64 {
func (wintun *Interface) LUID() uint64 {
return ((uint64(wintun.luidIndex) & ((1 << 24) - 1)) << 24) | ((uint64(wintun.ifType) & ((1 << 16) - 1)) << 48)
}