107 Commits

Author SHA1 Message Date
Jason A. Donenfeld
b16dba47a7 version: bump snapshot 2019-08-05 19:29:12 +02:00
Jason A. Donenfeld
4be9630ddc device: drop lock before expiring keys 2019-08-05 17:46:34 +02:00
Jason A. Donenfeld
4e3018a967 uapi: skip peers with invalid keys 2019-08-05 16:57:41 +02:00
Jason A. Donenfeld
b4010123f7 tun: windows: spin for only a millisecond/80
Performance stays the same as before.
2019-08-03 19:11:21 +02:00
Simon Rozman
1ff37e2b07 wintun: merge opening device registry key
This also introduces waiting for key to appear on initial access.

See if this resolves the issue caused by HDD power-up delay resulting in
failure to create the adapter.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-02 16:08:49 +02:00
Simon Rozman
f5e54932e6 wintun: simplify checking reboot requirement
We never checked checkReboot() reported error anyway.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-02 16:08:49 +02:00
Simon Rozman
73698066d1 wintun: refactor err == nil error checking
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-08-02 15:18:58 +02:00
Jason A. Donenfeld
05ece4d167 wintun: handle error for deadgwdetect 2019-08-02 14:37:09 +02:00
Jason A. Donenfeld
6d78f89557 tun: darwin: do not attempt to close tun.event twice
Previously it was possible for this to race. It turns out we really
don't need to set anything to -1 anyway.
2019-08-02 12:24:17 +02:00
Jason A. Donenfeld
a2249449d6 wintun: get interface path properly with cfgmgr 2019-07-23 14:58:46 +02:00
Jason A. Donenfeld
eeeac287ef tun: windows: style 2019-07-23 11:45:48 +02:00
Jason A. Donenfeld
b5a7cbf069 wintun: simplify resolution of dev node 2019-07-23 11:45:13 +02:00
Jason A. Donenfeld
50cd522cb0 wintun: enable sharing of pnp node 2019-07-22 17:01:27 +02:00
Jason A. Donenfeld
5ba866a5c8 tun: windows: close event handle on shutdown 2019-07-22 09:37:20 +02:00
Jason A. Donenfeld
2f101fedec ipc: windows: match SDDL of WDK and make monkeyable 2019-07-19 15:34:26 +02:00
Jason A. Donenfeld
3341e2d444 tun: windows: get rid of retry logic
Things work fine on Windows 8.
2019-07-19 14:01:34 +02:00
Jason A. Donenfeld
1b550f6583 tun: windows: use specific IOCTL code 2019-07-19 08:30:19 +02:00
Jason A. Donenfeld
7bc0e11831 device: do not crash on nil'd bind in windows binding 2019-07-18 19:34:45 +02:00
Jason A. Donenfeld
31ff9c02fe tun: windows: open file at startup time 2019-07-18 19:27:27 +02:00
Jason A. Donenfeld
1e39c33ab1 tun: windows: silently drop packet when ring is full 2019-07-18 15:48:34 +02:00
Jason A. Donenfeld
6c50fedd8e tun: windows: switch to NDIS device object 2019-07-18 12:26:57 +02:00
Jason A. Donenfeld
298d759f3e wintun: calculate path of NDIS device object symbolic link 2019-07-18 10:25:20 +02:00
Michael Zeltner
4d5819183e tun: openbsd: don't change MTU when it's already the expected size
Allows for running wireguard-go as non-root user.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2019-07-18 10:25:20 +02:00
Jason A. Donenfeld
9ea9a92117 tun: windows: spin for a bit before falling back to event object 2019-07-18 10:25:20 +02:00
Simon Rozman
2e24e7dcae tun: windows: implement ring buffers
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-07-17 14:32:13 +02:00
Jason A. Donenfeld
a961aacc9f device: immediately rekey all peers after changing device private key
Reported-by: Derrick Pallas <derrick@pallas.us>
2019-07-11 17:37:35 +02:00
Jason A. Donenfeld
b0cf53b078 README: update windows info 2019-07-08 14:52:49 +02:00
Jason A. Donenfeld
5c3d333f10 tun: windows: registration of write buffer no longer required 2019-07-05 14:17:48 +02:00
Jason A. Donenfeld
d8448f8a02 tun: windows: decrease alignment to 4 2019-07-05 07:53:19 +02:00
Jason A. Donenfeld
13abbdf14b tun: windows: delay initial write
Otherwise we provoke Wintun 0.3.
2019-07-04 22:41:42 +02:00
Jason A. Donenfeld
f361e59001 device: receive: uniform message for source address check 2019-07-01 15:24:50 +02:00
Jason A. Donenfeld
b844f1b3cc tun: windows: packetNum is unused 2019-07-01 15:23:44 +02:00
Jason A. Donenfeld
dd8817f50e device: receive: simplify flush loop 2019-07-01 15:23:24 +02:00
Jason A. Donenfeld
5e6eff81b6 tun: windows: inform wintun of maximum buffer length for writes 2019-06-26 13:27:48 +02:00
Jason A. Donenfeld
c69d026649 tun: windows: never retry open on Windows 10 2019-06-18 17:51:29 +02:00
Matt Layher
1f48971a80 tun: remove TUN prefix from types to reduce stutter elsewhere
Signed-off-by: Matt Layher <mdlayher@gmail.com>
2019-06-14 18:35:57 +02:00
Jason A. Donenfeld
3371f8dac6 device: update transfer counters correctly
The rule is to always update them to the full packet size minus UDP/IP
encapsulation for all authenticated packet types.
2019-06-11 18:13:52 +02:00
Jason A. Donenfeld
41fdbf0971 wintun: increase registry timeout 2019-06-11 00:33:07 +02:00
Jason A. Donenfeld
03eee4a778 wintun: add helper for cleaning up 2019-06-10 11:34:59 +02:00
Jason A. Donenfeld
700860f8e6 wintun: simplify error matching and remove dumb comments 2019-06-10 11:10:49 +02:00
Jason A. Donenfeld
a304f69e0d wintun: fix comments and remove hwnd param
This now looks more idiomatic.
2019-06-10 11:03:36 +02:00
Simon Rozman
baafe92888 setupapi: add SetDeviceRegistryPropertyString description
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-06-10 10:43:04 +02:00
Simon Rozman
a1a97d1e41 setupapi: unify ERROR_INSUFFICIENT_BUFFER handling
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-06-10 10:43:03 +02:00
Jason A. Donenfeld
e924280baa wintun: allow controlling GUID 2019-06-10 10:43:02 +02:00
Jason A. Donenfeld
bb3f1932fa setupapi: add DeviceInstanceID() 2019-06-10 10:43:01 +02:00
Jason A. Donenfeld
eaf17becfa global: fixup TODO comment spacing 2019-06-06 23:00:15 +02:00
Jason A. Donenfeld
6d8b68c8f3 wintun: guid functions are upstream 2019-06-06 22:39:20 +02:00
Simon Rozman
c2ed133df8 wintun: simplify DeleteInterface method signature
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-06-06 08:58:26 +02:00
Jason A. Donenfeld
108c37a056 wintun: don't run HrRenameConnection in separate thread
It's very slow, but unfortunately we haven't a choice. NLA needs this to
have completed.
2019-06-05 13:09:20 +02:00
Simon Rozman
e4b0ef29a1 tun: windows: obsolete 256 packets per exchange buffer limitation
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-06-05 11:55:28 +02:00
Simon Rozman
625e445b22 setupapi, wintun: replace syscall with golang.org/x/sys/windows
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-06-04 14:54:56 +02:00
Simon Rozman
85b85e62e5 wintun: set DI_QUIETINSTALL flag for GUI-less device management
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-06-04 14:45:23 +02:00
Simon Rozman
014f736480 setupapi: define PropChangeParams struct
This structure is required for calling DIF_PROPERTYCHANGE installer
class.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-06-04 14:45:23 +02:00
Matt Layher
43a4589043 device: remove redundant return statements
More staticcheck fixes:

$ staticcheck ./... | grep S1023
device/noise-helpers.go:45:2: redundant return statement (S1023)
device/noise-helpers.go:54:2: redundant return statement (S1023)
device/noise-helpers.go:64:2: redundant return statement (S1023)

Signed-off-by: Matt Layher <mdlayher@gmail.com>
2019-06-04 13:01:52 +02:00
Matt Layher
8d76ac8cc4 device: use bytes.Equal for equality check, simplify assertEqual
Signed-off-by: Matt Layher <mdlayher@gmail.com>
2019-06-04 13:01:52 +02:00
Matt Layher
18b6627f33 device, ratelimiter: replace uses of time.Now().Sub() with time.Since()
Simplification found by staticcheck:

$ staticcheck ./... | grep S1012
device/cookie.go:90:5: should use time.Since instead of time.Now().Sub (S1012)
device/cookie.go:127:5: should use time.Since instead of time.Now().Sub (S1012)
device/cookie.go:242:5: should use time.Since instead of time.Now().Sub (S1012)
device/noise-protocol.go:304:13: should use time.Since instead of time.Now().Sub (S1012)
device/receive.go:82:46: should use time.Since instead of time.Now().Sub (S1012)
device/send.go:132:5: should use time.Since instead of time.Now().Sub (S1012)
device/send.go:139:5: should use time.Since instead of time.Now().Sub (S1012)
device/send.go:235:59: should use time.Since instead of time.Now().Sub (S1012)
device/send.go:393:9: should use time.Since instead of time.Now().Sub (S1012)
ratelimiter/ratelimiter.go:79:10: should use time.Since instead of time.Now().Sub (S1012)
ratelimiter/ratelimiter.go:87:10: should use time.Since instead of time.Now().Sub (S1012)

Change applied using:

$ find . -type f -name "*.go" -exec sed -i "s/Now().Sub(/Since(/g" {} \;

Signed-off-by: Matt Layher <mdlayher@gmail.com>
2019-06-03 22:15:41 +02:00
Matt Layher
80ef2a42e6 ipc/winpipe: go fmt
Signed-off-by: Matt Layher <mdlayher@gmail.com>
2019-06-03 22:15:36 +02:00
Jason A. Donenfeld
da61947ec3 tun: windows: mitigate infinite loop in Flush()
It's possible that for whatever reason, we keep returning EOF, resulting
in repeated close/open/write operations, except with empty packets.
2019-05-31 16:55:03 +02:00
Jason A. Donenfeld
d9f995209c device: add SendKeepalivesToPeersWithCurrentKeypair for handover 2019-05-30 15:16:16 +02:00
Jason A. Donenfeld
d0ab883ada tai64n: account for whitening in test 2019-05-29 18:44:53 +02:00
Matt Layher
32912dc778 device, tun: rearrange code and fix device tests
Signed-off-by: Matt Layher <mdlayher@gmail.com>
2019-05-29 18:34:55 +02:00
Jason A. Donenfeld
d4034e5f8a wintun: remove extra / 2019-05-26 02:20:01 +02:00
Jason A. Donenfeld
fbcd995ec1 device: darwin actually doesn't need bound interfaces 2019-05-25 18:10:52 +02:00
Jason A. Donenfeld
e7e286ba6c device: make initiations per second match kernel implementation 2019-05-25 02:07:18 +02:00
Jason A. Donenfeld
f70546bc2e device: timers: add jitter on ack failure reinitiation 2019-05-24 13:48:25 +02:00
Simon Rozman
6a0a3a5406 wintun: revise GetInterface()
- Make foreign interface found error numeric to ease condition
  detection.
- Update GetInterface() documentation.
- Make tun.CreateTUN() quit when foreign interface found before
  attempting to create a Wintun interface with a duplicate name.
  Creation is futile.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-24 09:29:57 +02:00
Jason A. Donenfeld
8fdcf5ee30 wintun: never return nil, nil 2019-05-23 15:25:53 +02:00
Jason A. Donenfeld
a74a29bc93 ipc: use simplified fork of winio 2019-05-23 15:16:02 +02:00
Simon Rozman
dc9bbec9db setupapi: trim "Get" from getters
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-22 19:31:52 +02:00
Jason A. Donenfeld
a6dbe4f475 wintun: don't try to flush interface, but rather delete 2019-05-17 16:06:02 +02:00
Jason A. Donenfeld
c718f3940d device: fail to give bind if it doesn't exist 2019-05-17 15:35:20 +02:00
Jason A. Donenfeld
95c70b8032 wintun: make certain methods private 2019-05-17 15:01:08 +02:00
Jason A. Donenfeld
583ebe99f1 version: bump snapshot 2019-05-17 10:28:04 +02:00
Jason A. Donenfeld
a6dd282600 makefile: do not show warning on non-linux 2019-05-17 10:27:51 +02:00
Simon Rozman
7d5f5bcc0d wintun: change acronyms to uppercase
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-17 10:22:34 +02:00
Jason A. Donenfeld
3bf41b06ae global: regroup all imports 2019-05-14 09:09:52 +02:00
Jason A. Donenfeld
3147f00089 wintun: registry: fix nits 2019-05-11 17:25:48 +02:00
Simon Rozman
6c1b66802f wintun: registry: revise value reading
- Make getStringValueRetry() reusable for reading any value type. This
  merges code from GetIntegerValueWait().
- expandString() >> toString() and extend to support REG_MULTI_SZ
  (to return first value of REG_MULTI_SZ). Furthermore, doing our own
  UTF-16 to UTF-8 conversion works around a bug in windows/registry's
  GetStringValue() non-zero terminated string handling.
- Provide toInteger() analogous to toString()
- GetStringValueWait() tolerates and reads REG_MULTI_SZ too now. It
  returns REG_MULTI_SZ[0], making GetFirstStringValueWait() redundant.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-11 17:14:37 +02:00
Jason A. Donenfeld
5669ed326f wintun: call HrRenameConnection in another thread 2019-05-10 21:31:37 +02:00
Jason A. Donenfeld
2d847a38a2 wintun: add LUID accessor 2019-05-10 21:30:23 +02:00
Jason A. Donenfeld
7a8553aef0 wintun: enumerate faster by using COMPATDRIVER instead of CLASSDRIVER 2019-05-10 20:30:59 +02:00
Jason A. Donenfeld
a6045ac042 wintun: destroy devinfolist after usage 2019-05-10 20:19:11 +02:00
Simon Rozman
1c92b48415 wintun: registry: replace REG_NOTIFY with NOTIFY
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-10 18:09:20 +02:00
Jason A. Donenfeld
c267965bf8 wintun: IpConfig is a MULTI_SZ, and fix errors 2019-05-10 18:06:49 +02:00
Jason A. Donenfeld
1bf1dadf15 wintun: poll for device key
It's actually pretty hard to guess where it is.
2019-05-10 17:34:03 +02:00
Jason A. Donenfeld
f9dcfccbb7 wintun: fix scope of error object 2019-05-10 16:59:24 +02:00
Simon Rozman
7e962a9932 wintun: wait for interface registry key on device creation
By using RegNotifyChangeKeyValue(). Also disable dead gateway detection.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-10 16:43:58 +02:00
Jason A. Donenfeld
586112b5d7 conn: remove scope when sanity checking IP address format 2019-05-09 15:42:35 +02:00
Simon Rozman
dcb8f1aa6b wintun: fix GUID leading zero padding
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-09 12:16:21 +02:00
Jason A. Donenfeld
b16b0e4cf7 mod: update deps 2019-05-03 09:37:29 +02:00
Jason A. Donenfeld
81ca08f1b3 setupapi: safer aliasing of slice types 2019-05-03 09:34:00 +02:00
Jason A. Donenfeld
2e988467c2 wintun: work around GetInterface staleness bug 2019-05-03 00:42:36 +02:00
Jason A. Donenfeld
46dbf54040 wintun: don't retry when not creating
The only time we're trying to counteract the race condition is when
we're creating a driver. When we're simply looking up all drivers, it
doesn't make sense to retry.
2019-05-02 23:53:15 +02:00
Jason A. Donenfeld
247e14693a wintun: try harder to open registry key
This sucks. Can we please find a deterministic way of doing this
instead?
2019-04-29 14:00:49 +02:00
Jason A. Donenfeld
3945a299ff go.mod: use vendored winio 2019-04-29 08:09:38 +02:00
Jason A. Donenfeld
bb42ec7d18 tun: freebsd: work around numerous kernel panics on shutdown
There are numerous race conditions. But even this will crash it:

while true; do ifconfig tun0 create; ifconfig tun0 destroy; done

It seems like LLv6 is related, which we're not using anyway, so
explicitly disable it on the interface.
2019-04-23 18:00:23 +09:00
Simon Rozman
f1dc167901 setupapi: Fix struct size mismatches
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-04-19 10:08:11 +02:00
Jason A. Donenfeld
c7a26dfef3 setupapi: actually fix padding by rounding up to sizeof(void*) 2019-04-19 10:19:00 +09:00
Jason A. Donenfeld
d024393335 tun: darwin: write routeSocket variable in helper
Otherwise the race detector "complains".
2019-04-19 07:53:19 +09:00
Jason A. Donenfeld
d9078fe772 main: revise warnings 2019-04-19 07:48:09 +09:00
Jason A. Donenfeld
d3dd991e4e device: send: check packet length before freeing element 2019-04-18 23:23:03 +09:00
Simon Rozman
5811447b38 setupapi: Revise DrvInfoDetailData struct size calculation
Go adds trailing padding to DrvInfoDetailData struct in GOARCH=386 which
confuses SetupAPI expecting exactly sizeof(SP_DRVINFO_DETAIL_DATA).

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-04-18 10:39:22 +02:00
Jason A. Donenfeld
e0a8c22aa6 windows: use proper constants from updated x/sys 2019-04-13 02:02:02 +02:00
Jason A. Donenfeld
0b77bf78cd conn: linux: RTA_MARK has moved to x/sys 2019-04-13 02:01:20 +02:00
Simon Rozman
ef5f3ad80a tun: windows: Adopt new error codes returned by Wintun
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-04-11 19:38:11 +02:00
Simon Rozman
a291fdd746 tun: windows: do not sleep after OPERATION_ABORTED on write
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-04-11 19:37:04 +02:00
Jason A. Donenfeld
d50e390904 main_windows: use proper version constant 2019-04-09 10:45:40 +02:00
62 changed files with 2993 additions and 1262 deletions

View File

@@ -8,7 +8,7 @@ all: generate-version-and-build
ifeq ($(shell go env GOOS)|$(wildcard .git),linux|) ifeq ($(shell go env GOOS)|$(wildcard .git),linux|)
$(error Do not build this for Linux. Instead use the Linux kernel module. See wireguard.com/install/ for more info.) $(error Do not build this for Linux. Instead use the Linux kernel module. See wireguard.com/install/ for more info.)
else else ifeq ($(shell go env GOOS),linux)
ireallywantobuildon_linux.go: ireallywantobuildon_linux.go:
@printf "WARNING: This software is meant for use on non-Linux\nsystems. For Linux, please use the kernel module\ninstead. See wireguard.com/install/ for more info.\n\n" >&2 @printf "WARNING: This software is meant for use on non-Linux\nsystems. For Linux, please use the kernel module\ninstead. See wireguard.com/install/ for more info.\n\n" >&2
@printf 'package main\nconst UseTheKernelModuleInstead = 0xdeadbabe\n' > "$@" @printf 'package main\nconst UseTheKernelModuleInstead = 0xdeadbabe\n' > "$@"

View File

@@ -2,8 +2,6 @@
This is an implementation of WireGuard in Go. This is an implementation of WireGuard in Go.
***WARNING:*** This is a work in progress and not ready for prime time, with no official "releases" yet. It is extremely rough around the edges and leaves much to be desired. There are bugs and we are not yet in a position to make claims about its security. Beware.
## Usage ## Usage
Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run: Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run:
@@ -36,7 +34,7 @@ This runs on macOS using the utun driver. It does not yet support sticky sockets
### Windows ### Windows
It is currently a work in progress to strip out the beginnings of an experiment done with the OpenVPN tuntap driver and instead port to the new UWP APIs for tunnels. In other words, this does not *yet* work on Windows. This runs on Windows, but you should instead use it from the more [fully featured Windows app](https://git.zx2c4.com/wireguard-windows/about/), which uses this as a module.
### FreeBSD ### FreeBSD

View File

@@ -5,8 +5,14 @@
package device package device
import "errors"
func (device *Device) PeekLookAtSocketFd4() (fd int, err error) { func (device *Device) PeekLookAtSocketFd4() (fd int, err error) {
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn() nb, ok := device.net.bind.(*nativeBind)
if !ok {
return 0, errors.New("no socket exists")
}
sysconn, err := nb.ipv4.SyscallConn()
if err != nil { if err != nil {
return return
} }
@@ -20,7 +26,11 @@ func (device *Device) PeekLookAtSocketFd4() (fd int, err error) {
} }
func (device *Device) PeekLookAtSocketFd6() (fd int, err error) { func (device *Device) PeekLookAtSocketFd6() (fd int, err error) {
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn() nb, ok := device.net.bind.(*nativeBind)
if !ok {
return 0, errors.New("no socket exists")
}
sysconn, err := nb.ipv6.SyscallConn()
if err != nil { if err != nil {
return return
} }

View File

@@ -1,44 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"golang.org/x/sys/unix"
)
func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
if err != nil {
return nil
}
err2 := sysconn.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, int(interfaceIndex))
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
return nil
}
func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
if err != nil {
return nil
}
err2 := sysconn.Control(func(fd uintptr) {
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, int(interfaceIndex))
})
if err2 != nil {
return err2
}
if err != nil {
return err
}
return nil
}

View File

@@ -7,8 +7,10 @@ package device
import ( import (
"encoding/binary" "encoding/binary"
"golang.org/x/sys/windows" "errors"
"unsafe" "unsafe"
"golang.org/x/sys/windows"
) )
const ( const (
@@ -22,6 +24,10 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
binary.BigEndian.PutUint32(bytes, interfaceIndex) binary.BigEndian.PutUint32(bytes, interfaceIndex)
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0])) interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
if device.net.bind == nil {
return errors.New("Bind is not yet initialized")
}
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn() sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
if err != nil { if err != nil {
return err return err

View File

@@ -7,9 +7,11 @@ package device
import ( import (
"errors" "errors"
"net"
"strings"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"net"
) )
const ( const (
@@ -41,13 +43,18 @@ type Endpoint interface {
} }
func parseEndpoint(s string) (*net.UDPAddr, error) { func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address // ensure that the host is an IP address
host, _, err := net.SplitHostPort(s) host, _, err := net.SplitHostPort(s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
// trying to make sure with a small sanity test that this is a real IP address and
// not something that's likely to incur DNS lookups.
host = host[:i]
}
if ip := net.ParseIP(host); ip == nil { if ip := net.ParseIP(host); ip == nil {
return nil, errors.New("Failed to parse IP address: " + host) return nil, errors.New("Failed to parse IP address: " + host)
} }

View File

@@ -18,13 +18,14 @@ package device
import ( import (
"errors" "errors"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
"net" "net"
"strconv" "strconv"
"sync" "sync"
"syscall" "syscall"
"unsafe" "unsafe"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
) )
const ( const (
@@ -719,7 +720,7 @@ func (bind *nativeBind) routineRouteListener(device *Device) {
peer.endpoint.(*NativeEndpoint).src4().src, peer.endpoint.(*NativeEndpoint).src4().src,
unix.RtAttr{ unix.RtAttr{
Len: 8, Len: 8,
Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix Type: unix.RTA_MARK,
}, },
uint32(bind.lastMark), uint32(bind.lastMark),
} }

View File

@@ -22,7 +22,7 @@ const (
RejectAfterTime = time.Second * 180 RejectAfterTime = time.Second * 180
KeepaliveTimeout = time.Second * 10 KeepaliveTimeout = time.Second * 10
CookieRefreshTime = time.Second * 120 CookieRefreshTime = time.Second * 120
HandshakeInitationRate = time.Second / 20 HandshakeInitationRate = time.Second / 50
PaddingMultiple = 16 PaddingMultiple = 16
) )

View File

@@ -8,10 +8,11 @@ package device
import ( import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"sync" "sync"
"time" "time"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
) )
type CookieChecker struct { type CookieChecker struct {
@@ -86,7 +87,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
st.RLock() st.RLock()
defer st.RUnlock() defer st.RUnlock()
if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { if time.Since(st.mac2.secretSet) > CookieRefreshTime {
return false return false
} }
@@ -123,7 +124,7 @@ func (st *CookieChecker) CreateReply(
// refresh cookie secret // refresh cookie secret
if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime { if time.Since(st.mac2.secretSet) > CookieRefreshTime {
st.RUnlock() st.RUnlock()
st.Lock() st.Lock()
_, err := rand.Read(st.mac2.secret[:]) _, err := rand.Read(st.mac2.secret[:])
@@ -238,7 +239,7 @@ func (st *CookieGenerator) AddMacs(msg []byte) {
// set mac2 // set mac2
if time.Now().Sub(st.mac2.cookieSet) > CookieRefreshTime { if time.Since(st.mac2.cookieSet) > CookieRefreshTime {
return return
} }

View File

@@ -6,12 +6,13 @@
package device package device
import ( import (
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/tun"
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/tun"
) )
const ( const (
@@ -85,7 +86,7 @@ type Device struct {
} }
tun struct { tun struct {
device tun.TUNDevice device tun.Device
mtu int32 mtu int32
} }
} }
@@ -132,6 +133,7 @@ func deviceUpdateState(device *Device) {
switch newIsUp { switch newIsUp {
case true: case true:
if err := device.BindUpdate(); err != nil { if err := device.BindUpdate(); err != nil {
device.log.Error.Printf("Unable to update bind: %v\n", err)
device.isUp.Set(false) device.isUp.Set(false)
break break
} }
@@ -199,18 +201,22 @@ func (device *Device) IsUnderLoad() bool {
} }
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error { func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// lock required resources // lock required resources
device.staticIdentity.Lock() device.staticIdentity.Lock()
defer device.staticIdentity.Unlock() defer device.staticIdentity.Unlock()
if sk.Equals(device.staticIdentity.privateKey) {
return nil
}
device.peers.Lock() device.peers.Lock()
defer device.peers.Unlock() defer device.peers.Unlock()
lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.handshake.mutex.RLock() peer.handshake.mutex.RLock()
defer peer.handshake.mutex.RUnlock() lockedPeers = append(lockedPeers, peer)
} }
// remove peers with matching public keys // remove peers with matching public keys
@@ -232,8 +238,8 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
rmKey := device.staticIdentity.privateKey.IsZero() rmKey := device.staticIdentity.privateKey.IsZero()
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for key, peer := range device.peers.keyMap { for key, peer := range device.peers.keyMap {
handshake := &peer.handshake handshake := &peer.handshake
if rmKey { if rmKey {
@@ -244,13 +250,22 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
if isZero(handshake.precomputedStaticStatic[:]) { if isZero(handshake.precomputedStaticStatic[:]) {
unsafeRemovePeer(device, peer, key) unsafeRemovePeer(device, peer, key)
} else {
expiredPeers = append(expiredPeers, peer)
} }
} }
for _, peer := range lockedPeers {
peer.handshake.mutex.RUnlock()
}
for _, peer := range expiredPeers {
peer.ExpireCurrentKeypairs()
}
return nil return nil
} }
func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device { func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
device := new(Device) device := new(Device)
device.isUp.Set(false) device.isUp.Set(false)
@@ -322,7 +337,6 @@ func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
func (device *Device) RemovePeer(key NoisePublicKey) { func (device *Device) RemovePeer(key NoisePublicKey) {
device.peers.Lock() device.peers.Lock()
defer device.peers.Unlock() defer device.peers.Unlock()
// stop peer and remove from routing // stop peer and remove from routing
peer, ok := device.peers.keyMap[key] peer, ok := device.peers.keyMap[key]
@@ -394,3 +408,20 @@ func (device *Device) Close() {
func (device *Device) Wait() chan struct{} { func (device *Device) Wait() chan struct{} {
return device.signals.stop return device.signals.stop
} }
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
if device.isClosed.Get() {
return
}
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.keypairs.RLock()
sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
peer.keypairs.RUnlock()
if sendKeepalive {
peer.SendKeepalive()
}
}
device.peers.RUnlock()
}

View File

@@ -9,21 +9,17 @@ package device
* without network dependencies * without network dependencies
*/ */
import "testing" import (
"bytes"
"testing"
)
func TestDevice(t *testing.T) { func TestDevice(t *testing.T) {
// prepare tun devices for generating traffic // prepare tun devices for generating traffic
tun1, err := CreateDummyTUN("tun1") tun1 := newDummyTUN("tun1")
if err != nil { tun2 := newDummyTUN("tun2")
t.Error("failed to create tun:", err.Error())
}
tun2, err := CreateDummyTUN("tun2")
if err != nil {
t.Error("failed to create tun:", err.Error())
}
_ = tun1 _ = tun1
_ = tun2 _ = tun2
@@ -46,3 +42,27 @@ func TestDevice(t *testing.T) {
// create binds // create binds
} }
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
tun := newDummyTUN("dummy")
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
}
func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
}
}
func assertEqual(t *testing.T, a, b []byte) {
if !bytes.Equal(a, b) {
t.Fatal(a, "!=", b)
}
}

View File

@@ -7,8 +7,9 @@ package device
import ( import (
"encoding/hex" "encoding/hex"
"golang.org/x/crypto/blake2s"
"testing" "testing"
"golang.org/x/crypto/blake2s"
) )
type KDFTest struct { type KDFTest struct {

View File

@@ -7,9 +7,10 @@ package device
import ( import (
"crypto/cipher" "crypto/cipher"
"golang.zx2c4.com/wireguard/replay"
"sync" "sync"
"time" "time"
"golang.zx2c4.com/wireguard/replay"
) )
/* Due to limitations in Go and /x/crypto there is currently /* Due to limitations in Go and /x/crypto there is currently

View File

@@ -8,8 +8,9 @@
package device package device
import ( import (
"golang.org/x/sys/unix"
"runtime" "runtime"
"golang.org/x/sys/unix"
) )
var fwmarkIoctl int var fwmarkIoctl int

View File

@@ -9,9 +9,10 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/subtle" "crypto/subtle"
"hash"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
"hash"
) )
/* KDF related functions. /* KDF related functions.
@@ -41,7 +42,6 @@ func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) {
func KDF1(t0 *[blake2s.Size]byte, key, input []byte) { func KDF1(t0 *[blake2s.Size]byte, key, input []byte) {
HMAC1(t0, key, input) HMAC1(t0, key, input)
HMAC1(t0, t0[:], []byte{0x1}) HMAC1(t0, t0[:], []byte{0x1})
return
} }
func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) { func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
@@ -50,7 +50,6 @@ func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
HMAC1(t0, prk[:], []byte{0x1}) HMAC1(t0, prk[:], []byte{0x1})
HMAC2(t1, prk[:], t0[:], []byte{0x2}) HMAC2(t1, prk[:], t0[:], []byte{0x2})
setZero(prk[:]) setZero(prk[:])
return
} }
func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) { func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
@@ -60,7 +59,6 @@ func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
HMAC2(t1, prk[:], t0[:], []byte{0x2}) HMAC2(t1, prk[:], t0[:], []byte{0x2})
HMAC2(t2, prk[:], t1[:], []byte{0x3}) HMAC2(t2, prk[:], t1[:], []byte{0x3})
setZero(prk[:]) setZero(prk[:])
return
} }
func isZero(val []byte) bool { func isZero(val []byte) bool {

View File

@@ -7,12 +7,13 @@ package device
import ( import (
"errors" "errors"
"sync"
"time"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305" "golang.org/x/crypto/poly1305"
"golang.zx2c4.com/wireguard/tai64n" "golang.zx2c4.com/wireguard/tai64n"
"sync"
"time"
) )
const ( const (
@@ -300,7 +301,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
var ok bool var ok bool
ok = timestamp.After(handshake.lastTimestamp) ok = timestamp.After(handshake.lastTimestamp)
ok = ok && time.Now().Sub(handshake.lastInitiationConsumption) > HandshakeInitationRate ok = ok && time.Since(handshake.lastInitiationConsumption) > HandshakeInitationRate
handshake.mutex.RUnlock() handshake.mutex.RUnlock()
if !ok { if !ok {
return nil return nil

View File

@@ -9,6 +9,7 @@ import (
"crypto/subtle" "crypto/subtle"
"encoding/hex" "encoding/hex"
"errors" "errors"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
) )

View File

@@ -10,6 +10,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
"sync/atomic"
"time" "time"
) )
@@ -67,7 +68,6 @@ type Peer struct {
} }
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) { func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
if device.isClosed.Get() { if device.isClosed.Get() {
return nil, errors.New("device closed") return nil, errors.New("device closed")
} }
@@ -102,20 +102,28 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
if ok { if ok {
return nil, errors.New("adding existing peer") return nil, errors.New("adding existing peer")
} }
device.peers.keyMap[pk] = peer
// pre-compute DH // pre-compute DH
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
handshake.remoteStatic = pk
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
ssIsZero := isZero(handshake.precomputedStaticStatic[:])
handshake.remoteStatic = pk
handshake.mutex.Unlock() handshake.mutex.Unlock()
// reset endpoint // reset endpoint
peer.endpoint = nil peer.endpoint = nil
// conditionally add
if !ssIsZero {
device.peers.keyMap[pk] = peer
} else {
return nil, nil
}
// start peer // start peer
if peer.device.isUp.Get() { if peer.device.isUp.Get() {
@@ -140,7 +148,11 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
return errors.New("no known endpoint for peer") return errors.New("no known endpoint for peer")
} }
return peer.device.net.bind.Send(buffer, peer.endpoint) err := peer.device.net.bind.Send(buffer, peer.endpoint)
if err == nil {
atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
}
return err
} }
func (peer *Peer) String() string { func (peer *Peer) String() string {
@@ -227,6 +239,25 @@ func (peer *Peer) ZeroAndFlushAll() {
peer.FlushNonceQueue() peer.FlushNonceQueue()
} }
func (peer *Peer) ExpireCurrentKeypairs() {
handshake := &peer.handshake
handshake.mutex.Lock()
peer.device.indexTable.Delete(handshake.localIndex)
handshake.Clear()
handshake.mutex.Unlock()
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
keypairs := &peer.keypairs
keypairs.Lock()
if keypairs.current != nil {
keypairs.current.sendNonce = RejectAfterMessages
}
if keypairs.next != nil {
keypairs.next.sendNonce = RejectAfterMessages
}
keypairs.Unlock()
}
func (peer *Peer) Stop() { func (peer *Peer) Stop() {
// prevent simultaneous start/stop operations // prevent simultaneous start/stop operations

View File

@@ -8,14 +8,15 @@ package device
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net" "net"
"strconv" "strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
) )
type QueueHandshakeElement struct { type QueueHandshakeElement struct {
@@ -78,7 +79,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
return return
} }
keypair := peer.keypairs.Current() keypair := peer.keypairs.Current()
if keypair != nil && keypair.isInitiator && time.Now().Sub(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) { if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
peer.timers.sentLastMinuteHandshake.Set(true) peer.timers.sentLastMinuteHandshake.Set(true)
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
@@ -426,6 +427,7 @@ func (device *Device) RoutineHandshake() {
peer.SetEndpointFromPacket(elem.endpoint) peer.SetEndpointFromPacket(elem.endpoint)
logDebug.Println(peer, "- Received handshake initiation") logDebug.Println(peer, "- Received handshake initiation")
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
peer.SendHandshakeResponse() peer.SendHandshakeResponse()
@@ -456,6 +458,7 @@ func (device *Device) RoutineHandshake() {
peer.SetEndpointFromPacket(elem.endpoint) peer.SetEndpointFromPacket(elem.endpoint)
logDebug.Println(peer, "- Received handshake response") logDebug.Println(peer, "- Received handshake response")
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
// update timers // update timers
@@ -482,33 +485,6 @@ func (device *Device) RoutineHandshake() {
} }
} }
func (peer *Peer) elementStopOrFlush(shouldFlush *bool) (stop bool, elemOk bool, elem *QueueInboundElement) {
if !*shouldFlush {
select {
case <-peer.routines.stop:
stop = true
return
case elem, elemOk = <-peer.queue.inbound:
return
}
} else {
select {
case <-peer.routines.stop:
stop = true
return
case elem, elemOk = <-peer.queue.inbound:
return
default:
*shouldFlush = false
err := peer.device.tun.device.Flush()
if err != nil {
peer.device.log.Error.Printf("Unable to flush packets: %v", err)
}
return peer.elementStopOrFlush(shouldFlush)
}
}
}
func (peer *Peer) RoutineSequentialReceiver() { func (peer *Peer) RoutineSequentialReceiver() {
device := peer.device device := peer.device
@@ -517,10 +493,6 @@ func (peer *Peer) RoutineSequentialReceiver() {
logDebug := device.log.Debug logDebug := device.log.Debug
var elem *QueueInboundElement var elem *QueueInboundElement
var ok bool
var stop bool
shouldFlush := false
defer func() { defer func() {
logDebug.Println(peer, "- Routine: sequential receiver - stopped") logDebug.Println(peer, "- Routine: sequential receiver - stopped")
@@ -546,9 +518,14 @@ func (peer *Peer) RoutineSequentialReceiver() {
elem = nil elem = nil
} }
stop, ok, elem = peer.elementStopOrFlush(&shouldFlush) var elemOk bool
if stop || !ok { select {
case <-peer.routines.stop:
return return
case elem, elemOk = <-peer.queue.inbound:
if !elemOk {
return
}
} }
// wait for decryption // wait for decryption
@@ -580,6 +557,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
peer.keepKeyFreshReceiving() peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived() peer.timersAnyAuthenticatedPacketReceived()
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
// check for keepalive // check for keepalive
@@ -641,8 +619,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len] src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.LookupIPv6(src) != peer { if device.allowedips.LookupIPv6(src) != peer {
logInfo.Println( logInfo.Println(
"IPv6 packet with disallowed source address from",
peer, peer,
"sent packet with disallowed IPv6 source",
) )
continue continue
} }
@@ -655,10 +633,12 @@ func (peer *Peer) RoutineSequentialReceiver() {
// write to tun device // write to tun device
offset := MessageTransportOffsetContent offset := MessageTransportOffsetContent
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
_, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset) _, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
if err == nil { if len(peer.queue.inbound) == 0 {
shouldFlush = true err = device.tun.device.Flush()
if err != nil {
peer.device.log.Error.Printf("Unable to flush packets: %v", err)
}
} }
if err != nil && !device.isClosed.Get() { if err != nil && !device.isClosed.Get() {
logError.Println("Failed to write packet to TUN device:", err) logError.Println("Failed to write packet to TUN device:", err)

View File

@@ -8,13 +8,14 @@ package device
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
) )
/* Outbound flow /* Outbound flow
@@ -128,14 +129,14 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
} }
peer.handshake.mutex.RLock() peer.handshake.mutex.RLock()
if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout { if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
peer.handshake.mutex.RUnlock() peer.handshake.mutex.RUnlock()
return nil return nil
} }
peer.handshake.mutex.RUnlock() peer.handshake.mutex.RUnlock()
peer.handshake.mutex.Lock() peer.handshake.mutex.Lock()
if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout { if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
peer.handshake.mutex.Unlock() peer.handshake.mutex.Unlock()
return nil return nil
} }
@@ -231,7 +232,7 @@ func (peer *Peer) keepKeyFreshSending() {
return return
} }
nonce := atomic.LoadUint64(&keypair.sendNonce) nonce := atomic.LoadUint64(&keypair.sendNonce)
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Now().Sub(keypair.created) > RekeyAfterTime) { if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false) peer.SendHandshakeInitiation(false)
} }
} }
@@ -389,7 +390,7 @@ func (peer *Peer) RoutineNonce() {
keypair = peer.keypairs.Current() keypair = peer.keypairs.Current()
if keypair != nil && keypair.sendNonce < RejectAfterMessages { if keypair != nil && keypair.sendNonce < RejectAfterMessages {
if time.Now().Sub(keypair.created) < RejectAfterTime { if time.Since(keypair.created) < RejectAfterTime {
break break
} }
} }
@@ -599,19 +600,17 @@ func (peer *Peer) RoutineSequentialSender() {
// send message and return buffer to pool // send message and return buffer to pool
length := uint64(len(elem.packet))
err := peer.SendBuffer(elem.packet) err := peer.SendBuffer(elem.packet)
if len(elem.packet) != MessageKeepaliveSize {
peer.timersDataSent()
}
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem) device.PutOutboundElement(elem)
if err != nil { if err != nil {
logError.Println(peer, "- Failed to send data packet", err) logError.Println(peer, "- Failed to send data packet", err)
continue continue
} }
atomic.AddUint64(&peer.stats.txBytes, length)
if len(elem.packet) != MessageKeepaliveSize {
peer.timersDataSent()
}
peer.keepKeyFreshSending() peer.keepKeyFreshSending()
} }
} }

View File

@@ -147,7 +147,7 @@ func expiredPersistentKeepalive(peer *Peer) {
/* Should be called after an authenticated data packet is sent. */ /* Should be called after an authenticated data packet is sent. */
func (peer *Peer) timersDataSent() { func (peer *Peer) timersDataSent() {
if peer.timersActive() && !peer.timers.newHandshake.IsPending() { if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout) peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
} }
} }

View File

@@ -6,8 +6,9 @@
package device package device
import ( import (
"golang.zx2c4.com/wireguard/tun"
"sync/atomic" "sync/atomic"
"golang.zx2c4.com/wireguard/tun"
) )
const DefaultMTU = 1420 const DefaultMTU = 1420
@@ -22,7 +23,7 @@ func (device *Device) RoutineTUNEventReader() {
device.state.starting.Done() device.state.starting.Done()
for event := range device.tun.device.Events() { for event := range device.tun.device.Events() {
if event&tun.TUNEventMTUUpdate != 0 { if event&tun.EventMTUUpdate != 0 {
mtu, err := device.tun.device.MTU() mtu, err := device.tun.device.MTU()
old := atomic.LoadInt32(&device.tun.mtu) old := atomic.LoadInt32(&device.tun.mtu)
if err != nil { if err != nil {
@@ -37,13 +38,13 @@ func (device *Device) RoutineTUNEventReader() {
} }
} }
if event&tun.TUNEventUp != 0 && !setUp { if event&tun.EventUp != 0 && !setUp {
logInfo.Println("Interface set up") logInfo.Println("Interface set up")
setUp = true setUp = true
device.Up() device.Up()
} }
if event&tun.TUNEventDown != 0 && setUp { if event&tun.EventDown != 0 && setUp {
logInfo.Println("Interface set down") logInfo.Println("Interface set down")
setUp = false setUp = false
device.Down() device.Down()

56
device/tun_test.go Normal file
View File

@@ -0,0 +1,56 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
"os"
"golang.zx2c4.com/wireguard/tun"
)
// newDummyTUN creates a dummy TUN device with the specified name.
func newDummyTUN(name string) tun.Device {
return &dummyTUN{
name: name,
packets: make(chan []byte, 100),
events: make(chan tun.Event, 10),
}
}
// A dummyTUN is a tun.Device which is used in unit tests.
type dummyTUN struct {
name string
mtu int
packets chan []byte
events chan tun.Event
}
func (d *dummyTUN) Events() chan tun.Event { return d.events }
func (*dummyTUN) File() *os.File { return nil }
func (*dummyTUN) Flush() error { return nil }
func (d *dummyTUN) MTU() (int, error) { return d.mtu, nil }
func (d *dummyTUN) Name() (string, error) { return d.name, nil }
func (d *dummyTUN) Close() error {
close(d.events)
close(d.packets)
return nil
}
func (d *dummyTUN) Read(b []byte, offset int) (int, error) {
buf, ok := <-d.packets
if !ok {
return 0, errors.New("device closed")
}
copy(b[offset:], buf)
return len(buf), nil
}
func (d *dummyTUN) Write(b []byte, offset int) (int, error) {
d.packets <- b[offset:]
return len(b), nil
}

View File

@@ -8,13 +8,14 @@ package device
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"golang.zx2c4.com/wireguard/ipc"
"io" "io"
"net" "net"
"strconv" "strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
"golang.zx2c4.com/wireguard/ipc"
) )
type IPCError struct { type IPCError struct {
@@ -242,7 +243,12 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
logError.Println("Failed to create new peer:", err) logError.Println("Failed to create new peer:", err)
return &IPCError{ipc.IpcErrorInvalid} return &IPCError{ipc.IpcErrorInvalid}
} }
logDebug.Println(peer, "- UAPI: Created") if peer == nil {
dummy = true
peer = &Peer{}
} else {
logDebug.Println(peer, "- UAPI: Created")
}
} }
case "remove": case "remove":

View File

@@ -1,3 +1,3 @@
package device package device
const WireGuardGoVersion = "0.0.20190409" const WireGuardGoVersion = "0.0.20190805"

9
go.mod
View File

@@ -1,8 +1,9 @@
module golang.zx2c4.com/wireguard module golang.zx2c4.com/wireguard
go 1.12
require ( require (
github.com/Microsoft/go-winio v0.4.12 golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56
golang.org/x/crypto v0.0.0-20190320223903-b7391e95e576 golang.org/x/net v0.0.0-20190613194153-d28f0bde5980
golang.org/x/net v0.0.0-20190320064053-1272bf9dcd53 golang.org/x/sys v0.0.0-20190618155005-516e3c20635f
golang.org/x/sys v0.0.0-20190321052220-f7bb7a8bee54
) )

16
go.sum
View File

@@ -1,11 +1,11 @@
github.com/Microsoft/go-winio v0.4.12 h1:xAfWHN1IrQ0NJ9TBC0KBZoqLjzDTr1ML+4MywiUOryc=
github.com/Microsoft/go-winio v0.4.12/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190320223903-b7391e95e576 h1:aUX/1G2gFSs4AsJJg2cL3HuoRhCSCz733FE5GUSuaT4= golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56 h1:ZpKuNIejY8P0ExLOVyKhb0WsgG8UdvHXe6TWjY7eL6k=
golang.org/x/crypto v0.0.0-20190320223903-b7391e95e576/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190617133340-57b3e21c3d56/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/net v0.0.0-20190320064053-1272bf9dcd53 h1:kcXqo9vE6fsZY5X5Rd7R1l7fTgnWaDCVmln65REefiE= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190320064053-1272bf9dcd53/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190613194153-d28f0bde5980 h1:dfGZHvZk057jK2MCeWus/TowKpJ8y4AmooUzdBSR9GU=
golang.org/x/net v0.0.0-20190613194153-d28f0bde5980/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190321052220-f7bb7a8bee54 h1:xe1/2UUJRmA9iDglQSlkx8c5n3twv58+K0mPpC2zmhA= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190321052220-f7bb7a8bee54/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190618155005-516e3c20635f h1:dHNZYIYdq2QuU6w73vZ/DzesPbVlZVYZTtTZmrnsbQ8=
golang.org/x/sys v0.0.0-20190618155005-516e3c20635f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=

View File

@@ -10,11 +10,12 @@ package ipc
import ( import (
"errors" "errors"
"fmt" "fmt"
"golang.org/x/sys/unix"
"net" "net"
"os" "os"
"path" "path"
"unsafe" "unsafe"
"golang.org/x/sys/unix"
) )
var socketDirectory = "/var/run/wireguard" var socketDirectory = "/var/run/wireguard"

View File

@@ -8,11 +8,12 @@ package ipc
import ( import (
"errors" "errors"
"fmt" "fmt"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
"net" "net"
"os" "os"
"path" "path"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
) )
var socketDirectory = "/var/run/wireguard" var socketDirectory = "/var/run/wireguard"

View File

@@ -6,11 +6,12 @@
package ipc package ipc
import ( import (
"github.com/Microsoft/go-winio"
"net" "net"
"golang.zx2c4.com/wireguard/ipc/winpipe"
) )
//TODO: replace these with actual standard windows error numbers from the win package // TODO: replace these with actual standard windows error numbers from the win package
const ( const (
IpcErrorIO = -int64(5) IpcErrorIO = -int64(5)
IpcErrorProtocol = -int64(71) IpcErrorProtocol = -int64(71)
@@ -46,22 +47,14 @@ func (l *UAPIListener) Addr() net.Addr {
return l.listener.Addr() return l.listener.Addr()
} }
func GetSystemSecurityDescriptor() string { /* SDDL_DEVOBJ_SYS_ALL from the WDK */
// var UAPISecurityDescriptor = "O:SYD:P(A;;GA;;;SY)"
// SDDL encoded.
//
// (system = SECURITY_NT_AUTHORITY | SECURITY_LOCAL_SYSTEM_RID)
// owner: system
// grant: GENERIC_ALL to system
//
return "O:SYD:(A;;GA;;;SY)"
}
func UAPIListen(name string) (net.Listener, error) { func UAPIListen(name string) (net.Listener, error) {
config := winio.PipeConfig{ config := winpipe.PipeConfig{
SecurityDescriptor: GetSystemSecurityDescriptor(), SecurityDescriptor: UAPISecurityDescriptor,
} }
listener, err := winio.ListenPipe("\\\\.\\pipe\\WireGuard\\"+name, &config) listener, err := winpipe.ListenPipe("\\\\.\\pipe\\WireGuard\\"+name, &config)
if err != nil { if err != nil {
return nil, err return nil, err
} }

322
ipc/winpipe/file.go Normal file
View File

@@ -0,0 +1,322 @@
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package winpipe
import (
"errors"
"io"
"runtime"
"sync"
"sync/atomic"
"syscall"
"time"
)
//sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx
//sys createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) = CreateIoCompletionPort
//sys getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
//sys setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
//sys wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
type atomicBool int32
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
func (b *atomicBool) swap(new bool) bool {
var newInt int32
if new {
newInt = 1
}
return atomic.SwapInt32((*int32)(b), newInt) == 1
}
const (
cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1
cFILE_SKIP_SET_EVENT_ON_HANDLE = 2
)
var (
ErrFileClosed = errors.New("file has already been closed")
ErrTimeout = &timeoutError{}
)
type timeoutError struct{}
func (e *timeoutError) Error() string { return "i/o timeout" }
func (e *timeoutError) Timeout() bool { return true }
func (e *timeoutError) Temporary() bool { return true }
type timeoutChan chan struct{}
var ioInitOnce sync.Once
var ioCompletionPort syscall.Handle
// ioResult contains the result of an asynchronous IO operation
type ioResult struct {
bytes uint32
err error
}
// ioOperation represents an outstanding asynchronous Win32 IO
type ioOperation struct {
o syscall.Overlapped
ch chan ioResult
}
func initIo() {
h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff)
if err != nil {
panic(err)
}
ioCompletionPort = h
go ioCompletionProcessor(h)
}
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
// It takes ownership of this handle and will close it if it is garbage collected.
type win32File struct {
handle syscall.Handle
wg sync.WaitGroup
wgLock sync.RWMutex
closing atomicBool
socket bool
readDeadline deadlineHandler
writeDeadline deadlineHandler
}
type deadlineHandler struct {
setLock sync.Mutex
channel timeoutChan
channelLock sync.RWMutex
timer *time.Timer
timedout atomicBool
}
// makeWin32File makes a new win32File from an existing file handle
func makeWin32File(h syscall.Handle) (*win32File, error) {
f := &win32File{handle: h}
ioInitOnce.Do(initIo)
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
if err != nil {
return nil, err
}
err = setFileCompletionNotificationModes(h, cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS|cFILE_SKIP_SET_EVENT_ON_HANDLE)
if err != nil {
return nil, err
}
f.readDeadline.channel = make(timeoutChan)
f.writeDeadline.channel = make(timeoutChan)
return f, nil
}
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
return makeWin32File(h)
}
// closeHandle closes the resources associated with a Win32 handle
func (f *win32File) closeHandle() {
f.wgLock.Lock()
// Atomically set that we are closing, releasing the resources only once.
if !f.closing.swap(true) {
f.wgLock.Unlock()
// cancel all IO and wait for it to complete
cancelIoEx(f.handle, nil)
f.wg.Wait()
// at this point, no new IO can start
syscall.Close(f.handle)
f.handle = 0
} else {
f.wgLock.Unlock()
}
}
// Close closes a win32File.
func (f *win32File) Close() error {
f.closeHandle()
return nil
}
// prepareIo prepares for a new IO operation.
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
func (f *win32File) prepareIo() (*ioOperation, error) {
f.wgLock.RLock()
if f.closing.isSet() {
f.wgLock.RUnlock()
return nil, ErrFileClosed
}
f.wg.Add(1)
f.wgLock.RUnlock()
c := &ioOperation{}
c.ch = make(chan ioResult)
return c, nil
}
// ioCompletionProcessor processes completed async IOs forever
func ioCompletionProcessor(h syscall.Handle) {
for {
var bytes uint32
var key uintptr
var op *ioOperation
err := getQueuedCompletionStatus(h, &bytes, &key, &op, syscall.INFINITE)
if op == nil {
panic(err)
}
op.ch <- ioResult{bytes, err}
}
}
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed.
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
if err != syscall.ERROR_IO_PENDING {
return int(bytes), err
}
if f.closing.isSet() {
cancelIoEx(f.handle, &c.o)
}
var timeout timeoutChan
if d != nil {
d.channelLock.Lock()
timeout = d.channel
d.channelLock.Unlock()
}
var r ioResult
select {
case r = <-c.ch:
err = r.err
if err == syscall.ERROR_OPERATION_ABORTED {
if f.closing.isSet() {
err = ErrFileClosed
}
} else if err != nil && f.socket {
// err is from Win32. Query the overlapped structure to get the winsock error.
var bytes, flags uint32
err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
}
case <-timeout:
cancelIoEx(f.handle, &c.o)
r = <-c.ch
err = r.err
if err == syscall.ERROR_OPERATION_ABORTED {
err = ErrTimeout
}
}
// runtime.KeepAlive is needed, as c is passed via native
// code to ioCompletionProcessor, c must remain alive
// until the channel read is complete.
runtime.KeepAlive(c)
return int(r.bytes), err
}
// Read reads from a file handle.
func (f *win32File) Read(b []byte) (int, error) {
c, err := f.prepareIo()
if err != nil {
return 0, err
}
defer f.wg.Done()
if f.readDeadline.timedout.isSet() {
return 0, ErrTimeout
}
var bytes uint32
err = syscall.ReadFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
runtime.KeepAlive(b)
// Handle EOF conditions.
if err == nil && n == 0 && len(b) != 0 {
return 0, io.EOF
} else if err == syscall.ERROR_BROKEN_PIPE {
return 0, io.EOF
} else {
return n, err
}
}
// Write writes to a file handle.
func (f *win32File) Write(b []byte) (int, error) {
c, err := f.prepareIo()
if err != nil {
return 0, err
}
defer f.wg.Done()
if f.writeDeadline.timedout.isSet() {
return 0, ErrTimeout
}
var bytes uint32
err = syscall.WriteFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
runtime.KeepAlive(b)
return n, err
}
func (f *win32File) SetReadDeadline(deadline time.Time) error {
return f.readDeadline.set(deadline)
}
func (f *win32File) SetWriteDeadline(deadline time.Time) error {
return f.writeDeadline.set(deadline)
}
func (f *win32File) Flush() error {
return syscall.FlushFileBuffers(f.handle)
}
func (f *win32File) Fd() uintptr {
return uintptr(f.handle)
}
func (d *deadlineHandler) set(deadline time.Time) error {
d.setLock.Lock()
defer d.setLock.Unlock()
if d.timer != nil {
if !d.timer.Stop() {
<-d.channel
}
d.timer = nil
}
d.timedout.setFalse()
select {
case <-d.channel:
d.channelLock.Lock()
d.channel = make(chan struct{})
d.channelLock.Unlock()
default:
}
if deadline.IsZero() {
return nil
}
timeoutIO := func() {
d.timedout.setTrue()
close(d.channel)
}
now := time.Now()
duration := deadline.Sub(now)
if deadline.After(now) {
// Deadline is in the future, set a timer to wait
d.timer = time.AfterFunc(duration, timeoutIO)
} else {
// Deadline is in the past. Cancel all pending IO now.
timeoutIO()
}
return nil
}

9
ipc/winpipe/mksyscall.go Normal file
View File

@@ -0,0 +1,9 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package winpipe
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go pipe.go sd.go file.go

516
ipc/winpipe/pipe.go Normal file
View File

@@ -0,0 +1,516 @@
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package winpipe
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"runtime"
"syscall"
"time"
"unsafe"
)
//sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateNamedPipeW
//sys createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateFileW
//sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
//sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
type ioStatusBlock struct {
Status, Information uintptr
}
type objectAttributes struct {
Length uintptr
RootDirectory uintptr
ObjectName *unicodeString
Attributes uintptr
SecurityDescriptor *securityDescriptor
SecurityQoS uintptr
}
type unicodeString struct {
Length uint16
MaximumLength uint16
Buffer uintptr
}
type securityDescriptor struct {
Revision byte
Sbz1 byte
Control uint16
Owner uintptr
Group uintptr
Sacl uintptr
Dacl uintptr
}
type ntstatus int32
func (status ntstatus) Err() error {
if status >= 0 {
return nil
}
return rtlNtStatusToDosError(status)
}
const (
cERROR_PIPE_BUSY = syscall.Errno(231)
cERROR_NO_DATA = syscall.Errno(232)
cERROR_PIPE_CONNECTED = syscall.Errno(535)
cERROR_SEM_TIMEOUT = syscall.Errno(121)
cSECURITY_SQOS_PRESENT = 0x100000
cSECURITY_ANONYMOUS = 0
cPIPE_TYPE_MESSAGE = 4
cPIPE_READMODE_MESSAGE = 2
cFILE_OPEN = 1
cFILE_CREATE = 2
cFILE_PIPE_MESSAGE_TYPE = 1
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
cSE_DACL_PRESENT = 4
)
var (
// ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed.
// This error should match net.errClosing since docker takes a dependency on its text.
ErrPipeListenerClosed = errors.New("use of closed network connection")
errPipeWriteClosed = errors.New("pipe has been closed for write")
)
type win32Pipe struct {
*win32File
path string
}
type win32MessageBytePipe struct {
win32Pipe
writeClosed bool
readEOF bool
}
type pipeAddress string
func (f *win32Pipe) LocalAddr() net.Addr {
return pipeAddress(f.path)
}
func (f *win32Pipe) RemoteAddr() net.Addr {
return pipeAddress(f.path)
}
func (f *win32Pipe) SetDeadline(t time.Time) error {
f.SetReadDeadline(t)
f.SetWriteDeadline(t)
return nil
}
// CloseWrite closes the write side of a message pipe in byte mode.
func (f *win32MessageBytePipe) CloseWrite() error {
if f.writeClosed {
return errPipeWriteClosed
}
err := f.win32File.Flush()
if err != nil {
return err
}
_, err = f.win32File.Write(nil)
if err != nil {
return err
}
f.writeClosed = true
return nil
}
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
// they are used to implement CloseWrite().
func (f *win32MessageBytePipe) Write(b []byte) (int, error) {
if f.writeClosed {
return 0, errPipeWriteClosed
}
if len(b) == 0 {
return 0, nil
}
return f.win32File.Write(b)
}
// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
// mode pipe will return io.EOF, as will all subsequent reads.
func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
if f.readEOF {
return 0, io.EOF
}
n, err := f.win32File.Read(b)
if err == io.EOF {
// If this was the result of a zero-byte read, then
// it is possible that the read was due to a zero-size
// message. Since we are simulating CloseWrite with a
// zero-byte message, ensure that all future Read() calls
// also return EOF.
f.readEOF = true
} else if err == syscall.ERROR_MORE_DATA {
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
// and the message still has more bytes. Treat this as a success, since
// this package presents all named pipes as byte streams.
err = nil
}
return n, err
}
func (s pipeAddress) Network() string {
return "pipe"
}
func (s pipeAddress) String() string {
return string(s)
}
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
for {
select {
case <-ctx.Done():
return syscall.Handle(0), ctx.Err()
default:
h, err := createFile(*path, syscall.GENERIC_READ|syscall.GENERIC_WRITE, 0, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
if err == nil {
return h, nil
}
if err != cERROR_PIPE_BUSY {
return h, &os.PathError{Err: err, Op: "open", Path: *path}
}
// Wait 10 msec and try again. This is a rather simplistic
// view, as we always try each 10 milliseconds.
time.Sleep(time.Millisecond * 10)
}
}
}
// DialPipe connects to a named pipe by path, timing out if the connection
// takes longer than the specified duration. If timeout is nil, then we use
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
func DialPipe(path string, timeout *time.Duration) (net.Conn, error) {
var absTimeout time.Time
if timeout != nil {
absTimeout = time.Now().Add(*timeout)
} else {
absTimeout = time.Now().Add(time.Second * 2)
}
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
conn, err := DialPipeContext(ctx, path)
if err == context.DeadlineExceeded {
return nil, ErrTimeout
}
return conn, err
}
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
// cancellation or timeout.
func DialPipeContext(ctx context.Context, path string) (net.Conn, error) {
var err error
var h syscall.Handle
h, err = tryDialPipe(ctx, &path)
if err != nil {
return nil, err
}
var flags uint32
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
if err != nil {
return nil, err
}
f, err := makeWin32File(h)
if err != nil {
syscall.Close(h)
return nil, err
}
// If the pipe is in message mode, return a message byte pipe, which
// supports CloseWrite().
if flags&cPIPE_TYPE_MESSAGE != 0 {
return &win32MessageBytePipe{
win32Pipe: win32Pipe{win32File: f, path: path},
}, nil
}
return &win32Pipe{win32File: f, path: path}, nil
}
type acceptResponse struct {
f *win32File
err error
}
type win32PipeListener struct {
firstHandle syscall.Handle
path string
config PipeConfig
acceptCh chan (chan acceptResponse)
closeCh chan int
doneCh chan int
}
func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (syscall.Handle, error) {
path16, err := syscall.UTF16FromString(path)
if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
var oa objectAttributes
oa.Length = unsafe.Sizeof(oa)
var ntPath unicodeString
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
defer localFree(ntPath.Buffer)
oa.ObjectName = &ntPath
// The security descriptor is only needed for the first pipe.
if first {
if sd != nil {
len := uint32(len(sd))
sdb := localAlloc(0, len)
defer localFree(sdb)
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
} else {
// Construct the default named pipe security descriptor.
var dacl uintptr
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
}
defer localFree(dacl)
sdb := &securityDescriptor{
Revision: 1,
Control: cSE_DACL_PRESENT,
Dacl: dacl,
}
oa.SecurityDescriptor = sdb
}
}
typ := uint32(cFILE_PIPE_REJECT_REMOTE_CLIENTS)
if c.MessageMode {
typ |= cFILE_PIPE_MESSAGE_TYPE
}
disposition := uint32(cFILE_OPEN)
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE)
if first {
disposition = cFILE_CREATE
// By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking
// client connections until the next call with first == false.
access = syscall.SYNCHRONIZE
}
timeout := int64(-50 * 10000) // 50ms
var (
h syscall.Handle
iosb ioStatusBlock
)
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err}
}
runtime.KeepAlive(ntPath)
return h, nil
}
func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
if err != nil {
return nil, err
}
f, err := makeWin32File(h)
if err != nil {
syscall.Close(h)
return nil, err
}
return f, nil
}
func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) {
p, err := l.makeServerPipe()
if err != nil {
return nil, err
}
// Wait for the client to connect.
ch := make(chan error)
go func(p *win32File) {
ch <- connectPipe(p)
}(p)
select {
case err = <-ch:
if err != nil {
p.Close()
p = nil
}
case <-l.closeCh:
// Abort the connect request by closing the handle.
p.Close()
p = nil
err = <-ch
if err == nil || err == ErrFileClosed {
err = ErrPipeListenerClosed
}
}
return p, err
}
func (l *win32PipeListener) listenerRoutine() {
closed := false
for !closed {
select {
case <-l.closeCh:
closed = true
case responseCh := <-l.acceptCh:
var (
p *win32File
err error
)
for {
p, err = l.makeConnectedServerPipe()
// If the connection was immediately closed by the client, try
// again.
if err != cERROR_NO_DATA {
break
}
}
responseCh <- acceptResponse{p, err}
closed = err == ErrPipeListenerClosed
}
}
syscall.Close(l.firstHandle)
l.firstHandle = 0
// Notify Close() and Accept() callers that the handle has been closed.
close(l.doneCh)
}
// PipeConfig contain configuration for the pipe listener.
type PipeConfig struct {
// SecurityDescriptor contains a Windows security descriptor in SDDL format.
SecurityDescriptor string
// MessageMode determines whether the pipe is in byte or message mode. In either
// case the pipe is read in byte mode by default. The only practical difference in
// this implementation is that CloseWrite() is only supported for message mode pipes;
// CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only
// transferred to the reader (and returned as io.EOF in this implementation)
// when the pipe is in message mode.
MessageMode bool
// InputBufferSize specifies the size the input buffer, in bytes.
InputBufferSize int32
// OutputBufferSize specifies the size the input buffer, in bytes.
OutputBufferSize int32
}
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
// The pipe must not already exist.
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
var (
sd []byte
err error
)
if c == nil {
c = &PipeConfig{}
}
if c.SecurityDescriptor != "" {
sd, err = SddlToSecurityDescriptor(c.SecurityDescriptor)
if err != nil {
return nil, err
}
}
h, err := makeServerPipeHandle(path, sd, c, true)
if err != nil {
return nil, err
}
l := &win32PipeListener{
firstHandle: h,
path: path,
config: *c,
acceptCh: make(chan (chan acceptResponse)),
closeCh: make(chan int),
doneCh: make(chan int),
}
go l.listenerRoutine()
return l, nil
}
func connectPipe(p *win32File) error {
c, err := p.prepareIo()
if err != nil {
return err
}
defer p.wg.Done()
err = connectNamedPipe(p.handle, &c.o)
_, err = p.asyncIo(c, nil, 0, err)
if err != nil && err != cERROR_PIPE_CONNECTED {
return err
}
return nil
}
func (l *win32PipeListener) Accept() (net.Conn, error) {
ch := make(chan acceptResponse)
select {
case l.acceptCh <- ch:
response := <-ch
err := response.err
if err != nil {
return nil, err
}
if l.config.MessageMode {
return &win32MessageBytePipe{
win32Pipe: win32Pipe{win32File: response.f, path: l.path},
}, nil
}
return &win32Pipe{win32File: response.f, path: l.path}, nil
case <-l.doneCh:
return nil, ErrPipeListenerClosed
}
}
func (l *win32PipeListener) Close() error {
select {
case l.closeCh <- 1:
<-l.doneCh
case <-l.doneCh:
}
return nil
}
func (l *win32PipeListener) Addr() net.Addr {
return pipeAddress(l.path)
}

29
ipc/winpipe/sd.go Normal file
View File

@@ -0,0 +1,29 @@
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package winpipe
import (
"unsafe"
)
//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
//sys localFree(mem uintptr) = LocalFree
//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
var sdBuffer uintptr
err := convertStringSecurityDescriptorToSecurityDescriptor(sddl, 1, &sdBuffer, nil)
if err != nil {
return nil, err
}
defer localFree(sdBuffer)
sd := make([]byte, getSecurityDescriptorLength(sdBuffer))
copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
return sd, nil
}

View File

@@ -0,0 +1,274 @@
// Code generated by 'go generate'; DO NOT EDIT.
package winpipe
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return nil
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
modntdll = windows.NewLazySystemDLL("ntdll.dll")
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
procCreateFileW = modkernel32.NewProc("CreateFileW")
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
procLocalAlloc = modkernel32.NewProc("LocalAlloc")
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
procLocalFree = modkernel32.NewProc("LocalFree")
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
)
func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name)
if err != nil {
return
}
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
}
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
handle = syscall.Handle(r0)
if handle == syscall.InvalidHandle {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name)
if err != nil {
return
}
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
}
func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
handle = syscall.Handle(r0)
if handle == syscall.InvalidHandle {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
r0, _, _ := syscall.Syscall(procLocalAlloc.Addr(), 2, uintptr(uFlags), uintptr(length), 0)
ptr = uintptr(r0)
return
}
func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
status = ntstatus(r0)
return
}
func rtlNtStatusToDosError(status ntstatus) (winerr error) {
r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0)
if r0 != 0 {
winerr = syscall.Errno(r0)
}
return
}
func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) {
r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0)
status = ntstatus(r0)
return
}
func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) {
r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0)
status = ntstatus(r0)
return
}
func convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(str)
if err != nil {
return
}
return _convertStringSecurityDescriptorToSecurityDescriptor(_p0, revision, sd, size)
}
func _convertStringSecurityDescriptorToSecurityDescriptor(str *uint16, revision uint32, sd *uintptr, size *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procConvertStringSecurityDescriptorToSecurityDescriptorW.Addr(), 4, uintptr(unsafe.Pointer(str)), uintptr(revision), uintptr(unsafe.Pointer(sd)), uintptr(unsafe.Pointer(size)), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func localFree(mem uintptr) {
syscall.Syscall(procLocalFree.Addr(), 1, uintptr(mem), 0, 0)
return
}
func getSecurityDescriptorLength(sd uintptr) (len uint32) {
r0, _, _ := syscall.Syscall(procGetSecurityDescriptorLength.Addr(), 1, uintptr(sd), 0, 0)
len = uint32(r0)
return
}
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) {
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
newport = syscall.Handle(r0)
if newport == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
var _p0 uint32
if wait {
_p0 = 1
} else {
_p0 = 0
}
r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}

49
main.go
View File

@@ -9,14 +9,15 @@ package main
import ( import (
"fmt" "fmt"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
"os" "os"
"os/signal" "os/signal"
"runtime" "runtime"
"strconv" "strconv"
"syscall" "syscall"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
) )
const ( const (
@@ -36,37 +37,27 @@ func printUsage() {
} }
func warning() { func warning() {
if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" { if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
return return
} }
shouldQuit := os.Getenv("WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD") != "1"
shouldQuit := false
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING") fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W This is alpha software. It will very likely not G") fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G")
fmt.Fprintln(os.Stderr, "W do what it is supposed to do, and things may go G") fmt.Fprintln(os.Stderr, "W which is probably unnecessary and foolish. This G")
fmt.Fprintln(os.Stderr, "W horribly wrong. You have been warned. Proceed G") fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G")
fmt.Fprintln(os.Stderr, "W at your own risk. G") fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G")
if runtime.GOOS == "linux" { fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G")
shouldQuit = os.Getenv("WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD") != "1" fmt.Fprintln(os.Stderr, "W implementation. For more information on G")
fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G")
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
if shouldQuit {
fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W Furthermore, you are running this software on a G") fmt.Fprintln(os.Stderr, "W If you still want to use this program, against G")
fmt.Fprintln(os.Stderr, "W Linux kernel, which is probably unnecessary and G") fmt.Fprintln(os.Stderr, "W the advice here, please first export this G")
fmt.Fprintln(os.Stderr, "W foolish. This is because the Linux kernel has G") fmt.Fprintln(os.Stderr, "W environment variable: G")
fmt.Fprintln(os.Stderr, "W built-in first class support for WireGuard, and G") fmt.Fprintln(os.Stderr, "W WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD=1 G")
fmt.Fprintln(os.Stderr, "W this support is much more refined than this G")
fmt.Fprintln(os.Stderr, "W program. For more information on installing the G")
fmt.Fprintln(os.Stderr, "W kernel module, please visit: G")
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
if shouldQuit {
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W If you still want to use this program, against G")
fmt.Fprintln(os.Stderr, "W the sage advice here, please first export this G")
fmt.Fprintln(os.Stderr, "W environment variable: G")
fmt.Fprintln(os.Stderr, "W WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD=1 G")
}
} }
fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING") fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
@@ -134,7 +125,7 @@ func main() {
// open TUN device (or use supplied fd) // open TUN device (or use supplied fd)
tun, err := func() (tun.TUNDevice, error) { tun, err := func() (tun.Device, error) {
tunFdStr := os.Getenv(ENV_WG_TUN_FD) tunFdStr := os.Getenv(ENV_WG_TUN_FD)
if tunFdStr == "" { if tunFdStr == "" {
return tun.CreateTUN(interfaceName, device.DefaultMTU) return tun.CreateTUN(interfaceName, device.DefaultMTU)

View File

@@ -7,12 +7,13 @@ package main
import ( import (
"fmt" "fmt"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
@@ -27,11 +28,13 @@ func main() {
} }
interfaceName := os.Args[1] interfaceName := os.Args[1]
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.")
logger := device.NewLogger( logger := device.NewLogger(
device.LogLevelDebug, device.LogLevelDebug,
fmt.Sprintf("(%s) ", interfaceName), fmt.Sprintf("(%s) ", interfaceName),
) )
logger.Info.Println("Starting wireguard-go version", WireGuardGoVersion) logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
logger.Debug.Println("Debug log enabled") logger.Debug.Println("Debug log enabled")
tun, err := tun.CreateTUN(interfaceName) tun, err := tun.CreateTUN(interfaceName)

View File

@@ -76,7 +76,7 @@ func (rate *Ratelimiter) Init() {
for key, entry := range rate.tableIPv4 { for key, entry := range rate.tableIPv4 {
entry.Lock() entry.Lock()
if time.Now().Sub(entry.lastTime) > garbageCollectTime { if time.Since(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv4, key) delete(rate.tableIPv4, key)
} }
entry.Unlock() entry.Unlock()
@@ -84,7 +84,7 @@ func (rate *Ratelimiter) Init() {
for key, entry := range rate.tableIPv6 { for key, entry := range rate.tableIPv6 {
entry.Lock() entry.Lock()
if time.Now().Sub(entry.lastTime) > garbageCollectTime { if time.Since(entry.lastTime) > garbageCollectTime {
delete(rate.tableIPv6, key) delete(rate.tableIPv6, key)
} }
entry.Unlock() entry.Unlock()

View File

@@ -7,9 +7,10 @@ package rwcancel
import ( import (
"errors" "errors"
"golang.org/x/sys/unix"
"os" "os"
"syscall" "syscall"
"golang.org/x/sys/unix"
) )
func max(a, b int) int { func max(a, b int) int {

View File

@@ -15,11 +15,15 @@ import (
*/ */
func TestMonotonic(t *testing.T) { func TestMonotonic(t *testing.T) {
old := Now() old := Now()
for i := 0; i < 10000; i++ { for i := 0; i < 50; i++ {
time.Sleep(time.Nanosecond)
next := Now() next := Now()
if next.After(old) {
t.Error("Whitening insufficient")
}
time.Sleep(time.Duration(whitenerMask)/time.Nanosecond + 1)
next = Now()
if !next.After(old) { if !next.After(old) {
t.Error("TAI64N, not monotonically increasing on nano-second scale") t.Error("Not monotonically increasing on whitened nano-second scale")
} }
old = next old = next
} }

View File

@@ -1,92 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package tun
import (
"bytes"
"errors"
"golang.zx2c4.com/wireguard/tun"
"os"
"testing"
)
/* Helpers for writing unit tests
*/
type DummyTUN struct {
name string
mtu int
packets chan []byte
events chan tun.TUNEvent
}
func (tun *DummyTUN) File() *os.File {
return nil
}
func (tun *DummyTUN) Name() (string, error) {
return tun.name, nil
}
func (tun *DummyTUN) MTU() (int, error) {
return tun.mtu, nil
}
func (tun *DummyTUN) Write(d []byte, offset int) (int, error) {
tun.packets <- d[offset:]
return len(d), nil
}
func (tun *DummyTUN) Close() error {
close(tun.events)
close(tun.packets)
return nil
}
func (tun *DummyTUN) Events() chan tun.TUNEvent {
return tun.events
}
func (tun *DummyTUN) Read(d []byte, offset int) (int, error) {
t, ok := <-tun.packets
if !ok {
return 0, errors.New("device closed")
}
copy(d[offset:], t)
return len(t), nil
}
func CreateDummyTUN(name string) (tun.TUNDevice, error) {
var dummy DummyTUN
dummy.mtu = 0
dummy.packets = make(chan []byte, 100)
dummy.events = make(chan tun.TUNEvent, 10)
return &dummy, nil
}
func assertNil(t *testing.T, err error) {
if err != nil {
t.Fatal(err)
}
}
func assertEqual(t *testing.T, a []byte, b []byte) {
if bytes.Compare(a, b) != 0 {
t.Fatal(a, "!=", b)
}
}
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
tun, _ := CreateDummyTUN("dummy")
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
}

View File

@@ -9,21 +9,21 @@ import (
"os" "os"
) )
type TUNEvent int type Event int
const ( const (
TUNEventUp = 1 << iota EventUp = 1 << iota
TUNEventDown EventDown
TUNEventMTUUpdate EventMTUUpdate
) )
type TUNDevice interface { type Device interface {
File() *os.File // returns the file descriptor of the device File() *os.File // returns the file descriptor of the device
Read([]byte, int) (int, error) // read a packet from the device (without any additional headers) Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers) Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
Flush() error // flush all previous writes to the device Flush() error // flush all previous writes to the device
MTU() (int, error) // returns the MTU of the device MTU() (int, error) // returns the MTU of the device
Name() (string, error) // fetches and returns the current name Name() (string, error) // fetches and returns the current name
Events() chan TUNEvent // returns a constant channel of events related to the device Events() chan Event // returns a constant channel of events related to the device
Close() error // stops the device and closes the event channel Close() error // stops the device and closes the event channel
} }

View File

@@ -7,13 +7,14 @@ package tun
import ( import (
"fmt" "fmt"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"io/ioutil" "io/ioutil"
"net" "net"
"os" "os"
"syscall" "syscall"
"unsafe" "unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
) )
const utunControlName = "com.apple.net.utun_control" const utunControlName = "com.apple.net.utun_control"
@@ -34,7 +35,7 @@ type sockaddrCtl struct {
type NativeTun struct { type NativeTun struct {
name string name string
tunFile *os.File tunFile *os.File
events chan TUNEvent events chan Event
errors chan error errors chan error
routeSocket int routeSocket int
} }
@@ -82,22 +83,22 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
// Up / Down event // Up / Down event
up := (iface.Flags & net.FlagUp) != 0 up := (iface.Flags & net.FlagUp) != 0
if up != statusUp && up { if up != statusUp && up {
tun.events <- TUNEventUp tun.events <- EventUp
} }
if up != statusUp && !up { if up != statusUp && !up {
tun.events <- TUNEventDown tun.events <- EventDown
} }
statusUp = up statusUp = up
// MTU changes // MTU changes
if iface.MTU != statusMTU { if iface.MTU != statusMTU {
tun.events <- TUNEventMTUUpdate tun.events <- EventMTUUpdate
} }
statusMTU = iface.MTU statusMTU = iface.MTU
} }
} }
func CreateTUN(name string, mtu int) (TUNDevice, error) { func CreateTUN(name string, mtu int) (Device, error) {
ifIndex := -1 ifIndex := -1
if name != "utun" { if name != "utun" {
_, err := fmt.Sscanf(name, "utun%d", &ifIndex) _, err := fmt.Sscanf(name, "utun%d", &ifIndex)
@@ -167,10 +168,10 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
return tun, err return tun, err
} }
func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) { func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{ tun := &NativeTun{
tunFile: file, tunFile: file,
events: make(chan TUNEvent, 10), events: make(chan Event, 10),
errors: make(chan error, 5), errors: make(chan error, 5),
} }
@@ -240,7 +241,7 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile return tun.tunFile
} }
func (tun *NativeTun) Events() chan TUNEvent { func (tun *NativeTun) Events() chan Event {
return tun.events return tun.events
} }
@@ -282,7 +283,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
} }
func (tun *NativeTun) Flush() error { func (tun *NativeTun) Flush() error {
//TODO: can flushing be implemented by buffering and using sendmmsg? // TODO: can flushing be implemented by buffering and using sendmmsg?
return nil return nil
} }
@@ -292,7 +293,6 @@ func (tun *NativeTun) Close() error {
if tun.routeSocket != -1 { if tun.routeSocket != -1 {
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
err2 = unix.Close(tun.routeSocket) err2 = unix.Close(tun.routeSocket)
tun.routeSocket = -1
} else if tun.events != nil { } else if tun.events != nil {
close(tun.events) close(tun.events)
} }

View File

@@ -9,19 +9,30 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"net" "net"
"os" "os"
"syscall" "syscall"
"unsafe" "unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
) )
// _TUNSIFHEAD, value derived from sys/net/{if_tun,ioccom}.h // _TUNSIFHEAD, value derived from sys/net/{if_tun,ioccom}.h
// const _TUNSIFHEAD = ((0x80000000) | (((4) & ((1 << 13) - 1) ) << 16) | (uint32(byte('t')) << 8) | (96)) // const _TUNSIFHEAD = ((0x80000000) | (((4) & ((1 << 13) - 1) ) << 16) | (uint32(byte('t')) << 8) | (96))
const _TUNSIFHEAD = 0x80047460 const (
const _TUNSIFMODE = 0x8004745e _TUNSIFHEAD = 0x80047460
const _TUNSIFPID = 0x2000745f _TUNSIFMODE = 0x8004745e
_TUNSIFPID = 0x2000745f
)
// TODO: move into x/sys/unix
const (
SIOCGIFINFO_IN6 = 0xc048696c
SIOCSIFINFO_IN6 = 0xc048696d
ND6_IFF_AUTO_LINKLOCAL = 0x20
ND6_IFF_NO_DAD = 0x100
)
// Iface status string max len // Iface status string max len
const _IFSTATMAX = 800 const _IFSTATMAX = 800
@@ -32,7 +43,7 @@ const SIZEOF_UINTPTR = 4 << (^uintptr(0) >> 32 & 1)
type ifreq_ptr struct { type ifreq_ptr struct {
Name [unix.IFNAMSIZ]byte Name [unix.IFNAMSIZ]byte
Data uintptr Data uintptr
Pad0 [24 - SIZEOF_UINTPTR]byte Pad0 [16 - SIZEOF_UINTPTR]byte
} }
// Structure for iface mtu get/set ioctls // Structure for iface mtu get/set ioctls
@@ -48,10 +59,27 @@ type ifstat struct {
Ascii [_IFSTATMAX]byte Ascii [_IFSTATMAX]byte
} }
// Structures for nd6 flag manipulation
type in6_ndireq struct {
Name [unix.IFNAMSIZ]byte
Linkmtu uint32
Maxmtu uint32
Basereachable uint32
Reachable uint32
Retrans uint32
Flags uint32
Recalctm int
Chlim uint8
Initialized uint8
Randomseed0 [8]byte
Randomseed1 [8]byte
Randomid [8]byte
}
type NativeTun struct { type NativeTun struct {
name string name string
tunFile *os.File tunFile *os.File
events chan TUNEvent events chan Event
errors chan error errors chan error
routeSocket int routeSocket int
} }
@@ -97,16 +125,16 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
// Up / Down event // Up / Down event
up := (iface.Flags & net.FlagUp) != 0 up := (iface.Flags & net.FlagUp) != 0
if up != statusUp && up { if up != statusUp && up {
tun.events <- TUNEventUp tun.events <- EventUp
} }
if up != statusUp && !up { if up != statusUp && !up {
tun.events <- TUNEventDown tun.events <- EventDown
} }
statusUp = up statusUp = up
// MTU changes // MTU changes
if iface.MTU != statusMTU { if iface.MTU != statusMTU {
tun.events <- TUNEventMTUUpdate tun.events <- EventMTUUpdate
} }
statusMTU = iface.MTU statusMTU = iface.MTU
} }
@@ -191,23 +219,18 @@ func tunName(fd uintptr) (string, error) {
// Destroy a named system interface // Destroy a named system interface
func tunDestroy(name string) error { func tunDestroy(name string) error {
// open control socket // Open control socket.
var fd int var fd int
fd, err := unix.Socket( fd, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM,
0, 0,
) )
if err != nil { if err != nil {
return err return err
} }
defer unix.Close(fd) defer unix.Close(fd)
// do ioctl call
var ifr [32]byte var ifr [32]byte
copy(ifr[:], name) copy(ifr[:], name)
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(
@@ -216,7 +239,6 @@ func tunDestroy(name string) error {
uintptr(unix.SIOCIFDESTROY), uintptr(unix.SIOCIFDESTROY),
uintptr(unsafe.Pointer(&ifr[0])), uintptr(unsafe.Pointer(&ifr[0])),
) )
if errno != 0 { if errno != 0 {
return fmt.Errorf("failed to destroy interface %s: %s", name, errno.Error()) return fmt.Errorf("failed to destroy interface %s: %s", name, errno.Error())
} }
@@ -224,7 +246,7 @@ func tunDestroy(name string) error {
return nil return nil
} }
func CreateTUN(name string, mtu int) (TUNDevice, error) { func CreateTUN(name string, mtu int) (Device, error) {
if len(name) > unix.IFNAMSIZ-1 { if len(name) > unix.IFNAMSIZ-1 {
return nil, errors.New("interface name too long") return nil, errors.New("interface name too long")
} }
@@ -263,33 +285,71 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
}) })
if errno != 0 { if errno != 0 {
return nil, fmt.Errorf("error %s", errno.Error()) tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to put into IFHEAD mode: %v", errno)
} }
// Rename tun interface // Open control sockets
// Open control socket
confd, err := unix.Socket( confd, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM,
0, 0,
) )
if err != nil { if err != nil {
tunFile.Close()
tunDestroy(assignedName)
return nil, err return nil, err
} }
defer unix.Close(confd) defer unix.Close(confd)
confd6, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
0,
)
if err != nil {
tunFile.Close()
tunDestroy(assignedName)
return nil, err
}
defer unix.Close(confd6)
// set up struct for iface rename // Disable link-local v6, not just because WireGuard doesn't do that anyway, but
// also because there are serious races with attaching and detaching LLv6 addresses
// in relation to interface lifetime within the FreeBSD kernel.
var ndireq in6_ndireq
copy(ndireq.Name[:], assignedName)
_, _, errno = unix.Syscall(
unix.SYS_IOCTL,
uintptr(confd6),
uintptr(SIOCGIFINFO_IN6),
uintptr(unsafe.Pointer(&ndireq)),
)
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to get nd6 flags for %s: %v", assignedName, errno)
}
ndireq.Flags = ndireq.Flags &^ ND6_IFF_AUTO_LINKLOCAL
ndireq.Flags = ndireq.Flags | ND6_IFF_NO_DAD
_, _, errno = unix.Syscall(
unix.SYS_IOCTL,
uintptr(confd6),
uintptr(SIOCSIFINFO_IN6),
uintptr(unsafe.Pointer(&ndireq)),
)
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to set nd6 flags for %s: %v", assignedName, errno)
}
// Rename the interface
var newnp [unix.IFNAMSIZ]byte var newnp [unix.IFNAMSIZ]byte
copy(newnp[:], name) copy(newnp[:], name)
var ifr ifreq_ptr var ifr ifreq_ptr
copy(ifr.Name[:], assignedName) copy(ifr.Name[:], assignedName)
ifr.Data = uintptr(unsafe.Pointer(&newnp[0])) ifr.Data = uintptr(unsafe.Pointer(&newnp[0]))
//do actual ioctl to rename iface
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(
unix.SYS_IOCTL, unix.SYS_IOCTL,
uintptr(confd), uintptr(confd),
@@ -298,18 +358,18 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
) )
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(name) tunDestroy(assignedName)
return nil, fmt.Errorf("failed to rename %s to %s: %s", assignedName, name, errno.Error()) return nil, fmt.Errorf("Failed to rename %s to %s: %v", assignedName, name, errno)
} }
return CreateTUNFromFile(tunFile, mtu) return CreateTUNFromFile(tunFile, mtu)
} }
func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) { func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{ tun := &NativeTun{
tunFile: file, tunFile: file,
events: make(chan TUNEvent, 10), events: make(chan Event, 10),
errors: make(chan error, 1), errors: make(chan error, 1),
} }
@@ -365,7 +425,7 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile return tun.tunFile
} }
func (tun *NativeTun) Events() chan TUNEvent { func (tun *NativeTun) Events() chan Event {
return tun.events return tun.events
} }
@@ -407,7 +467,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
} }
func (tun *NativeTun) Flush() error { func (tun *NativeTun) Flush() error {
//TODO: can flushing be implemented by buffering and using sendmmsg? // TODO: can flushing be implemented by buffering and using sendmmsg?
return nil return nil
} }

View File

@@ -12,15 +12,16 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
"net" "net"
"os" "os"
"sync" "sync"
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
) )
const ( const (
@@ -30,11 +31,11 @@ const (
type NativeTun struct { type NativeTun struct {
tunFile *os.File tunFile *os.File
index int32 // if index index int32 // if index
name string // name of interface name string // name of interface
errors chan error // async error handling errors chan error // async error handling
events chan TUNEvent // device related events events chan Event // device related events
nopi bool // the device was pased IFF_NO_PI nopi bool // the device was pased IFF_NO_PI
netlinkSock int netlinkSock int
netlinkCancel *rwcancel.RWCancel netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex hackListenerClosed sync.Mutex
@@ -63,9 +64,9 @@ func (tun *NativeTun) routineHackListener() {
} }
switch err { switch err {
case unix.EINVAL: case unix.EINVAL:
tun.events <- TUNEventUp tun.events <- EventUp
case unix.EIO: case unix.EIO:
tun.events <- TUNEventDown tun.events <- EventDown
default: default:
return return
} }
@@ -147,14 +148,14 @@ func (tun *NativeTun) routineNetlinkListener() {
} }
if info.Flags&unix.IFF_RUNNING != 0 { if info.Flags&unix.IFF_RUNNING != 0 {
tun.events <- TUNEventUp tun.events <- EventUp
} }
if info.Flags&unix.IFF_RUNNING == 0 { if info.Flags&unix.IFF_RUNNING == 0 {
tun.events <- TUNEventDown tun.events <- EventDown
} }
tun.events <- TUNEventMTUUpdate tun.events <- EventMTUUpdate
default: default:
remain = remain[hdr.Len:] remain = remain[hdr.Len:]
@@ -319,7 +320,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
} }
func (tun *NativeTun) Flush() error { func (tun *NativeTun) Flush() error {
//TODO: can flushing be implemented by buffering and using sendmmsg? // TODO: can flushing be implemented by buffering and using sendmmsg?
return nil return nil
} }
@@ -341,7 +342,7 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
} }
} }
func (tun *NativeTun) Events() chan TUNEvent { func (tun *NativeTun) Events() chan Event {
return tun.events return tun.events
} }
@@ -363,7 +364,7 @@ func (tun *NativeTun) Close() error {
return err2 return err2
} }
func CreateTUN(name string, mtu int) (TUNDevice, error) { func CreateTUN(name string, mtu int) (Device, error) {
nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0) nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -399,10 +400,10 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
return CreateTUNFromFile(fd, mtu) return CreateTUNFromFile(fd, mtu)
} }
func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) { func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{ tun := &NativeTun{
tunFile: file, tunFile: file,
events: make(chan TUNEvent, 5), events: make(chan Event, 5),
errors: make(chan error, 5), errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}), statusListenersShutdown: make(chan struct{}),
nopi: false, nopi: false,
@@ -444,7 +445,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
return tun, nil return tun, nil
} }
func CreateUnmonitoredTUNFromFD(fd int) (TUNDevice, string, error) { func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
err := unix.SetNonblock(fd, true) err := unix.SetNonblock(fd, true)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
@@ -452,7 +453,7 @@ func CreateUnmonitoredTUNFromFD(fd int) (TUNDevice, string, error) {
file := os.NewFile(uintptr(fd), "/dev/tun") file := os.NewFile(uintptr(fd), "/dev/tun")
tun := &NativeTun{ tun := &NativeTun{
tunFile: file, tunFile: file,
events: make(chan TUNEvent, 5), events: make(chan Event, 5),
errors: make(chan error, 5), errors: make(chan error, 5),
nopi: true, nopi: true,
} }

View File

@@ -7,13 +7,14 @@ package tun
import ( import (
"fmt" "fmt"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"io/ioutil" "io/ioutil"
"net" "net"
"os" "os"
"syscall" "syscall"
"unsafe" "unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
) )
// Structure for iface mtu get/set ioctls // Structure for iface mtu get/set ioctls
@@ -28,7 +29,7 @@ const _TUNSIFMODE = 0x8004745d
type NativeTun struct { type NativeTun struct {
name string name string
tunFile *os.File tunFile *os.File
events chan TUNEvent events chan Event
errors chan error errors chan error
routeSocket int routeSocket int
} }
@@ -74,16 +75,16 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
// Up / Down event // Up / Down event
up := (iface.Flags & net.FlagUp) != 0 up := (iface.Flags & net.FlagUp) != 0
if up != statusUp && up { if up != statusUp && up {
tun.events <- TUNEventUp tun.events <- EventUp
} }
if up != statusUp && !up { if up != statusUp && !up {
tun.events <- TUNEventDown tun.events <- EventDown
} }
statusUp = up statusUp = up
// MTU changes // MTU changes
if iface.MTU != statusMTU { if iface.MTU != statusMTU {
tun.events <- TUNEventMTUUpdate tun.events <- EventMTUUpdate
} }
statusMTU = iface.MTU statusMTU = iface.MTU
} }
@@ -99,7 +100,7 @@ func errorIsEBUSY(err error) bool {
return false return false
} }
func CreateTUN(name string, mtu int) (TUNDevice, error) { func CreateTUN(name string, mtu int) (Device, error) {
ifIndex := -1 ifIndex := -1
if name != "tun" { if name != "tun" {
_, err := fmt.Sscanf(name, "tun%d", &ifIndex) _, err := fmt.Sscanf(name, "tun%d", &ifIndex)
@@ -138,11 +139,11 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
return tun, err return tun, err
} }
func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) { func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{ tun := &NativeTun{
tunFile: file, tunFile: file,
events: make(chan TUNEvent, 10), events: make(chan Event, 10),
errors: make(chan error, 1), errors: make(chan error, 1),
} }
@@ -172,10 +173,13 @@ func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
go tun.routineRouteListener(tunIfindex) go tun.routineRouteListener(tunIfindex)
err = tun.setMTU(mtu) currentMTU, err := tun.MTU()
if err != nil { if err != nil || currentMTU != mtu {
tun.Close() err = tun.setMTU(mtu)
return nil, err if err != nil {
tun.Close()
return nil, err
}
} }
return tun, nil return tun, nil
@@ -196,7 +200,7 @@ func (tun *NativeTun) File() *os.File {
return tun.tunFile return tun.tunFile
} }
func (tun *NativeTun) Events() chan TUNEvent { func (tun *NativeTun) Events() chan Event {
return tun.events return tun.events
} }
@@ -238,7 +242,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
} }
func (tun *NativeTun) Flush() error { func (tun *NativeTun) Flush() error {
//TODO: can flushing be implemented by buffering and using sendmmsg? // TODO: can flushing be implemented by buffering and using sendmmsg?
return nil return nil
} }

View File

@@ -7,367 +7,288 @@ package tun
import ( import (
"errors" "errors"
"fmt"
"os" "os"
"sync" "sync/atomic"
"syscall"
"time" "time"
"unsafe" "unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/tun/wintun" "golang.zx2c4.com/wireguard/tun/wintun"
) )
const ( const (
packetExchangeMax uint32 = 256 // Number of packets that may be written at a time packetAlignment uint32 = 4 // Number of bytes packets are aligned to in rings
packetExchangeAlignment uint32 = 16 // Number of bytes packets are aligned to in exchange buffers packetSizeMax = 0xffff // Maximum packet size
packetSizeMax uint32 = 0xf000 - packetExchangeAlignment // Maximum packet size packetCapacity = 0x800000 // Ring capacity, 8MiB
packetExchangeSize uint32 = 0x100000 // Exchange buffer size (defaults to 1MiB) packetTrailingSize = uint32(unsafe.Sizeof(packetHeader{})) + ((packetSizeMax + (packetAlignment - 1)) &^ (packetAlignment - 1)) - packetAlignment
retryRate = 4 // Number of retries per second to reopen device pipe ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14)
retryTimeout = 30 // Number of seconds to tolerate adapter unavailable
) )
type exchgBufRead struct { type packetHeader struct {
data [packetExchangeSize]byte size uint32
offset uint32
avail uint32
} }
type exchgBufWrite struct { type packet struct {
data [packetExchangeSize]byte packetHeader
offset uint32 data [packetSizeMax]byte
packetNum uint32 }
type ring struct {
head uint32
tail uint32
alertable int32
data [packetCapacity + packetTrailingSize]byte
}
type ringDescriptor struct {
send, receive struct {
size uint32
ring *ring
tailMoved windows.Handle
}
} }
type NativeTun struct { type NativeTun struct {
wt *wintun.Wintun wt *wintun.Wintun
tunFileRead *os.File handle windows.Handle
tunFileWrite *os.File close bool
tunLock sync.Mutex rings ringDescriptor
close bool events chan Event
rdBuff *exchgBufRead errors chan error
wrBuff *exchgBufWrite forcedMTU int
events chan TUNEvent
errors chan error
forcedMtu int
} }
func packetAlign(size uint32) uint32 { func packetAlign(size uint32) uint32 {
return (size + (packetExchangeAlignment - 1)) &^ (packetExchangeAlignment - 1) return (size + (packetAlignment - 1)) &^ (packetAlignment - 1)
} }
// //
// CreateTUN creates a Wintun adapter with the given name. Should a Wintun // CreateTUN creates a Wintun adapter with the given name. Should a Wintun
// adapter with the same name exist, it is reused. // adapter with the same name exist, it is reused.
// //
func CreateTUN(ifname string) (TUNDevice, error) { func CreateTUN(ifname string) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, nil)
}
//
// CreateTUNWithRequestedGUID creates a Wintun adapter with the given name and
// a requested GUID. Should a Wintun adapter with the same name exist, it is reused.
//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Device, error) {
var err error
var wt *wintun.Wintun
// Does an interface with this name already exist? // Does an interface with this name already exist?
wt, err := wintun.GetInterface(ifname, 0) wt, err = wintun.GetInterface(ifname)
if wt == nil { if err == nil {
// Interface does not exist or an error occured. Create one. // If so, we delete it, in case it has weird residual configuration.
wt, _, err = wintun.CreateInterface("WireGuard Tunnel Adapter", 0) _, err = wt.DeleteInterface()
if err != nil { if err != nil {
return nil, errors.New("Creating Wintun adapter failed: " + err.Error()) return nil, fmt.Errorf("Unable to delete already existing Wintun interface: %v", err)
} }
} else if err != nil { } else if err == windows.ERROR_ALREADY_EXISTS {
// Foreign interface with the same name found. return nil, fmt.Errorf("Foreign network interface with the same name exists")
// We could create a Wintun interface under a temporary name. But, should our }
// proces die without deleting this interface first, the interface would remain wt, _, err = wintun.CreateInterface("WireGuard Tunnel Adapter", requestedGUID)
// orphaned. if err != nil {
return nil, err return nil, fmt.Errorf("Unable to create Wintun interface: %v", err)
} }
err = wt.SetInterfaceName(ifname) err = wt.SetInterfaceName(ifname)
if err != nil { if err != nil {
wt.DeleteInterface(0) wt.DeleteInterface()
return nil, fmt.Errorf("Unable to set name of Wintun interface: %v", err)
}
tun := &NativeTun{
wt: wt,
handle: windows.InvalidHandle,
events: make(chan Event, 10),
errors: make(chan error, 1),
forcedMTU: 1500,
}
tun.rings.send.size = uint32(unsafe.Sizeof(ring{}))
tun.rings.send.ring = &ring{}
tun.rings.send.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
tun.Close()
return nil, fmt.Errorf("Error creating event: %v", err)
}
tun.rings.receive.size = uint32(unsafe.Sizeof(ring{}))
tun.rings.receive.ring = &ring{}
tun.rings.receive.tailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
tun.Close()
return nil, fmt.Errorf("Error creating event: %v", err)
}
tun.handle, err = tun.wt.AdapterHandle()
if err != nil {
tun.Close()
return nil, err return nil, err
} }
err = wt.FlushInterface() var bytesReturned uint32
err = windows.DeviceIoControl(tun.handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(&tun.rings)), uint32(unsafe.Sizeof(tun.rings)), nil, 0, &bytesReturned, nil)
if err != nil { if err != nil {
wt.DeleteInterface(0) tun.Close()
return nil, errors.New("Flushing interface failed: " + err.Error()) return nil, fmt.Errorf("Error registering rings: %v", err)
} }
return tun, nil
return &NativeTun{
wt: wt,
rdBuff: &exchgBufRead{},
wrBuff: &exchgBufWrite{},
events: make(chan TUNEvent, 10),
errors: make(chan error, 1),
forcedMtu: 1500,
}, nil
}
func (tun *NativeTun) openTUN() error {
retries := retryTimeout * retryRate
if tun.close {
return os.ErrClosed
}
var err error
name := tun.wt.DataFileName()
for tun.tunFileRead == nil {
tun.tunFileRead, err = os.OpenFile(name, os.O_RDONLY, 0)
if err != nil {
if retries > 0 && !tun.close {
time.Sleep(time.Second / retryRate)
retries--
continue
}
return err
}
}
for tun.tunFileWrite == nil {
tun.tunFileWrite, err = os.OpenFile(name, os.O_WRONLY, 0)
if err != nil {
if retries > 0 && !tun.close {
time.Sleep(time.Second / retryRate)
retries--
continue
}
return err
}
}
return nil
}
func (tun *NativeTun) closeTUN() (err error) {
for tun.tunFileRead != nil {
tun.tunLock.Lock()
if tun.tunFileRead == nil {
tun.tunLock.Unlock()
break
}
t := tun.tunFileRead
tun.tunFileRead = nil
windows.CancelIoEx(windows.Handle(t.Fd()), nil)
err = t.Close()
tun.tunLock.Unlock()
break
}
for tun.tunFileWrite != nil {
tun.tunLock.Lock()
if tun.tunFileWrite == nil {
tun.tunLock.Unlock()
break
}
t := tun.tunFileWrite
tun.tunFileWrite = nil
windows.CancelIoEx(windows.Handle(t.Fd()), nil)
err2 := t.Close()
tun.tunLock.Unlock()
if err == nil {
err = err2
}
break
}
return
}
func (tun *NativeTun) getTUN() (read *os.File, write *os.File, err error) {
read, write = tun.tunFileRead, tun.tunFileWrite
if read == nil || write == nil {
read, write = nil, nil
tun.tunLock.Lock()
if tun.tunFileRead != nil && tun.tunFileWrite != nil {
read, write = tun.tunFileRead, tun.tunFileWrite
tun.tunLock.Unlock()
return
}
err = tun.closeTUN()
if err != nil {
tun.tunLock.Unlock()
return
}
err = tun.openTUN()
if err == nil {
read, write = tun.tunFileRead, tun.tunFileWrite
}
tun.tunLock.Unlock()
return
}
return
} }
func (tun *NativeTun) Name() (string, error) { func (tun *NativeTun) Name() (string, error) {
return tun.wt.GetInterfaceName() return tun.wt.InterfaceName()
} }
func (tun *NativeTun) File() *os.File { func (tun *NativeTun) File() *os.File {
return nil return nil
} }
func (tun *NativeTun) Events() chan TUNEvent { func (tun *NativeTun) Events() chan Event {
return tun.events return tun.events
} }
func (tun *NativeTun) Close() error { func (tun *NativeTun) Close() error {
tun.close = true tun.close = true
err1 := tun.closeTUN() if tun.rings.send.tailMoved != 0 {
windows.SetEvent(tun.rings.send.tailMoved) // wake the reader if it's sleeping
if tun.events != nil {
close(tun.events)
} }
if tun.handle != windows.InvalidHandle {
_, _, err2 := tun.wt.DeleteInterface(0) windows.CloseHandle(tun.handle)
if err1 == nil {
err1 = err2
} }
if tun.rings.send.tailMoved != 0 {
return err1 windows.CloseHandle(tun.rings.send.tailMoved)
}
if tun.rings.send.tailMoved != 0 {
windows.CloseHandle(tun.rings.receive.tailMoved)
}
var err error
if tun.wt != nil {
_, err = tun.wt.DeleteInterface()
}
close(tun.events)
return err
} }
func (tun *NativeTun) MTU() (int, error) { func (tun *NativeTun) MTU() (int, error) {
return tun.forcedMtu, nil return tun.forcedMTU, nil
} }
//TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. // TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMtu(mtu int) { func (tun *NativeTun) ForceMTU(mtu int) {
tun.forcedMtu = mtu tun.forcedMTU = mtu
} }
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
retry:
select { select {
case err := <-tun.errors: case err := <-tun.errors:
return 0, err return 0, err
default: default:
} }
if tun.close {
return 0, os.ErrClosed
}
buffHead := atomic.LoadUint32(&tun.rings.send.ring.head)
if buffHead >= packetCapacity {
return 0, os.ErrClosed
}
start := time.Now()
var buffTail uint32
for { for {
if tun.rdBuff.offset+packetExchangeAlignment <= tun.rdBuff.avail { buffTail = atomic.LoadUint32(&tun.rings.send.ring.tail)
// Get packet from the exchange buffer. if buffHead != buffTail {
packet := tun.rdBuff.data[tun.rdBuff.offset:]
size := *(*uint32)(unsafe.Pointer(&packet[0]))
pSize := packetAlign(packetExchangeAlignment + size)
if packetSizeMax < size || tun.rdBuff.avail < tun.rdBuff.offset+pSize {
// Invalid packet size.
tun.rdBuff.avail = 0
continue
}
packet = packet[packetExchangeAlignment : packetExchangeAlignment+size]
// Copy data.
copy(buff[offset:], packet)
tun.rdBuff.offset += pSize
return int(size), nil
}
// Get TUN data pipe.
file, _, err := tun.getTUN()
if err != nil {
return 0, err
}
// Fill queue.
retries := 1000
for {
n, err := file.Read(tun.rdBuff.data[:])
if err != nil {
tun.rdBuff.offset = 0
tun.rdBuff.avail = 0
pe, ok := err.(*os.PathError)
if tun.close {
return 0, os.ErrClosed
}
if retries > 0 && ok && pe.Err == windows.ERROR_OPERATION_ABORTED {
retries--
continue
}
if ok && pe.Err == syscall.Errno(6) /*windows.ERROR_INVALID_HANDLE*/ {
tun.closeTUN()
break
}
return 0, err
}
tun.rdBuff.offset = 0
tun.rdBuff.avail = uint32(n)
break break
} }
if tun.close {
return 0, os.ErrClosed
}
if time.Since(start) >= time.Millisecond/80 /* ~1gbit/s */ {
windows.WaitForSingleObject(tun.rings.send.tailMoved, windows.INFINITE)
goto retry
}
procyield(1)
}
if buffTail >= packetCapacity {
return 0, os.ErrClosed
} }
}
// Note: flush() and putTunPacket() assume the caller comes only from a single thread; there's no locking. buffContent := tun.rings.send.ring.wrap(buffTail - buffHead)
if buffContent < uint32(unsafe.Sizeof(packetHeader{})) {
return 0, errors.New("incomplete packet header in send ring")
}
packet := (*packet)(unsafe.Pointer(&tun.rings.send.ring.data[buffHead]))
if packet.size > packetSizeMax {
return 0, errors.New("packet too big in send ring")
}
alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packet.size)
if alignedPacketSize > buffContent {
return 0, errors.New("incomplete packet in send ring")
}
copy(buff[offset:], packet.data[:packet.size])
buffHead = tun.rings.send.ring.wrap(buffHead + alignedPacketSize)
atomic.StoreUint32(&tun.rings.send.ring.head, buffHead)
return int(packet.size), nil
}
func (tun *NativeTun) Flush() error { func (tun *NativeTun) Flush() error {
if tun.wrBuff.offset == 0 {
return nil
}
for {
// Get TUN data pipe.
_, file, err := tun.getTUN()
if err != nil {
return err
}
// Flush write buffer.
retries := retryTimeout * retryRate
for {
_, err = file.Write(tun.wrBuff.data[:tun.wrBuff.offset])
tun.wrBuff.packetNum = 0
tun.wrBuff.offset = 0
if err != nil {
pe, ok := err.(*os.PathError)
if tun.close {
return os.ErrClosed
}
if retries > 0 && ok && pe.Err == windows.ERROR_OPERATION_ABORTED {
time.Sleep(time.Second / retryRate)
retries--
continue
}
if ok && pe.Err == syscall.Errno(6) /*windows.ERROR_INVALID_HANDLE*/ {
tun.closeTUN()
break
}
return err
}
return nil
}
}
}
func (tun *NativeTun) putTunPacket(buff []byte) error {
size := uint32(len(buff))
if size == 0 {
return errors.New("Empty packet")
}
if size > packetSizeMax {
return errors.New("Packet too big")
}
pSize := packetAlign(packetExchangeAlignment + size)
if tun.wrBuff.packetNum >= packetExchangeMax || tun.wrBuff.offset+pSize >= packetExchangeSize {
// Exchange buffer is full -> flush first.
err := tun.Flush()
if err != nil {
return err
}
}
// Write packet to the exchange buffer.
packet := tun.wrBuff.data[tun.wrBuff.offset : tun.wrBuff.offset+pSize]
*(*uint32)(unsafe.Pointer(&packet[0])) = size
packet = packet[packetExchangeAlignment : packetExchangeAlignment+size]
copy(packet, buff)
tun.wrBuff.packetNum++
tun.wrBuff.offset += pSize
return nil return nil
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
err := tun.putTunPacket(buff[offset:]) if tun.close {
if err != nil { return 0, os.ErrClosed
return 0, err
} }
return len(buff) - offset, nil
packetSize := uint32(len(buff) - offset)
alignedPacketSize := packetAlign(uint32(unsafe.Sizeof(packetHeader{})) + packetSize)
buffHead := atomic.LoadUint32(&tun.rings.receive.ring.head)
if buffHead >= packetCapacity {
return 0, os.ErrClosed
}
buffTail := atomic.LoadUint32(&tun.rings.receive.ring.tail)
if buffTail >= packetCapacity {
return 0, os.ErrClosed
}
buffSpace := tun.rings.receive.ring.wrap(buffHead - buffTail - packetAlignment)
if alignedPacketSize > buffSpace {
return 0, nil // Dropping when ring is full.
}
packet := (*packet)(unsafe.Pointer(&tun.rings.receive.ring.data[buffTail]))
packet.size = packetSize
copy(packet.data[:packetSize], buff[offset:])
atomic.StoreUint32(&tun.rings.receive.ring.tail, tun.rings.receive.ring.wrap(buffTail+alignedPacketSize))
if atomic.LoadInt32(&tun.rings.receive.ring.alertable) != 0 {
windows.SetEvent(tun.rings.receive.tailMoved)
}
return int(packetSize), nil
} }
// // LUID returns Windows adapter instance ID.
// GUID returns Windows adapter instance ID. func (tun *NativeTun) LUID() uint64 {
// return tun.wt.LUID()
func (tun *NativeTun) GUID() windows.GUID { }
return tun.wt.CfgInstanceID
// wrap returns value modulo ring capacity
func (rb *ring) wrap(value uint32) uint32 {
return value & (packetCapacity - 1)
} }

View File

@@ -1,44 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package guid
import (
"fmt"
"syscall"
"golang.org/x/sys/windows"
)
//sys clsidFromString(lpsz *uint16, pclsid *windows.GUID) (hr int32) = ole32.CLSIDFromString
//
// FromString parses "{XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}" string to GUID.
//
func FromString(str string) (*windows.GUID, error) {
strUTF16, err := syscall.UTF16PtrFromString(str)
if err != nil {
return nil, err
}
guid := &windows.GUID{}
hr := clsidFromString(strUTF16, guid)
if hr < 0 {
return nil, syscall.Errno(hr)
}
return guid, nil
}
//
// ToString function converts GUID to string
// "{XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}".
//
// The resulting string is uppercase.
//
func ToString(guid *windows.GUID) string {
return fmt.Sprintf("{%06X-%04X-%04X-%04X-%012X}", guid.Data1, guid.Data2, guid.Data3, guid.Data4[:2], guid.Data4[2:])
}

View File

@@ -13,7 +13,7 @@ import (
) )
var ( var (
modnetshell = windows.NewLazySystemDLL("netshell.dll") modnetshell = windows.NewLazySystemDLL("netshell.dll")
procHrRenameConnection = modnetshell.NewProc("HrRenameConnection") procHrRenameConnection = modnetshell.NewProc("HrRenameConnection")
) )

View File

@@ -3,6 +3,6 @@
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved. * Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/ */
package guid package registry
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zguid_windows.go guid_windows.go //go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zregistry_windows.go registry_windows.go

View File

@@ -0,0 +1,272 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package registry
import (
"errors"
"fmt"
"runtime"
"strings"
"time"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
const (
// REG_NOTIFY_CHANGE_NAME notifies the caller if a subkey is added or deleted.
REG_NOTIFY_CHANGE_NAME uint32 = 0x00000001
// REG_NOTIFY_CHANGE_ATTRIBUTES notifies the caller of changes to the attributes of the key, such as the security descriptor information.
REG_NOTIFY_CHANGE_ATTRIBUTES uint32 = 0x00000002
// REG_NOTIFY_CHANGE_LAST_SET notifies the caller of changes to a value of the key. This can include adding or deleting a value, or changing an existing value.
REG_NOTIFY_CHANGE_LAST_SET uint32 = 0x00000004
// REG_NOTIFY_CHANGE_SECURITY notifies the caller of changes to the security descriptor of the key.
REG_NOTIFY_CHANGE_SECURITY uint32 = 0x00000008
// REG_NOTIFY_THREAD_AGNOSTIC indicates that the lifetime of the registration must not be tied to the lifetime of the thread issuing the RegNotifyChangeKeyValue call. Note: This flag value is only supported in Windows 8 and later.
REG_NOTIFY_THREAD_AGNOSTIC uint32 = 0x10000000
)
//sys regNotifyChangeKeyValue(key windows.Handle, watchSubtree bool, notifyFilter uint32, event windows.Handle, asynchronous bool) (regerrno error) = advapi32.RegNotifyChangeKeyValue
func OpenKeyWait(k registry.Key, path string, access uint32, timeout time.Duration) (registry.Key, error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
deadline := time.Now().Add(timeout)
pathSpl := strings.Split(path, "\\")
for i := 0; ; i++ {
keyName := pathSpl[i]
isLast := i+1 == len(pathSpl)
event, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return 0, fmt.Errorf("Error creating event: %v", err)
}
defer windows.CloseHandle(event)
var key registry.Key
for {
err = regNotifyChangeKeyValue(windows.Handle(k), false, REG_NOTIFY_CHANGE_NAME, windows.Handle(event), true)
if err != nil {
return 0, fmt.Errorf("Setting up change notification on registry key failed: %v", err)
}
var accessFlags uint32
if isLast {
accessFlags = access
} else {
accessFlags = registry.NOTIFY
}
key, err = registry.OpenKey(k, keyName, accessFlags)
if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
timeout := time.Until(deadline) / time.Millisecond
if timeout < 0 {
timeout = 0
}
s, err := windows.WaitForSingleObject(event, uint32(timeout))
if err != nil {
return 0, fmt.Errorf("Unable to wait on registry key: %v", err)
}
if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
return 0, errors.New("Timeout waiting for registry key")
}
} else if err != nil {
return 0, fmt.Errorf("Error opening registry key %v: %v", path, err)
} else {
if isLast {
return key, nil
}
defer key.Close()
break
}
}
k = key
}
}
func WaitForKey(k registry.Key, path string, timeout time.Duration) error {
key, err := OpenKeyWait(k, path, registry.NOTIFY, timeout)
if err != nil {
return err
}
key.Close()
return nil
}
//
// getValue is more or less the same as windows/registry's getValue.
//
func getValue(k registry.Key, name string, buf []byte) (value []byte, valueType uint32, err error) {
var name16 *uint16
name16, err = windows.UTF16PtrFromString(name)
if err != nil {
return
}
n := uint32(len(buf))
for {
err = windows.RegQueryValueEx(windows.Handle(k), name16, nil, &valueType, (*byte)(unsafe.Pointer(&buf[0])), &n)
if err == nil {
value = buf[:n]
return
}
if err != windows.ERROR_MORE_DATA {
return
}
if n <= uint32(len(buf)) {
return
}
buf = make([]byte, n)
}
}
//
// getValueRetry function reads any value from registry. It waits for
// the registry value to become available or returns error on timeout.
//
// Key must be opened with at least QUERY_VALUE|NOTIFY access.
//
func getValueRetry(key registry.Key, name string, buf []byte, timeout time.Duration) ([]byte, uint32, error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
event, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return nil, 0, fmt.Errorf("Error creating event: %v", err)
}
defer windows.CloseHandle(event)
deadline := time.Now().Add(timeout)
for {
err := regNotifyChangeKeyValue(windows.Handle(key), false, REG_NOTIFY_CHANGE_LAST_SET, windows.Handle(event), true)
if err != nil {
return nil, 0, fmt.Errorf("Setting up change notification on registry value failed: %v", err)
}
buf, valueType, err := getValue(key, name, buf)
if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
timeout := time.Until(deadline) / time.Millisecond
if timeout < 0 {
timeout = 0
}
s, err := windows.WaitForSingleObject(event, uint32(timeout))
if err != nil {
return nil, 0, fmt.Errorf("Unable to wait on registry value: %v", err)
}
if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
return nil, 0, errors.New("Timeout waiting for registry value")
}
} else if err != nil {
return nil, 0, fmt.Errorf("Error reading registry value %v: %v", name, err)
} else {
return buf, valueType, nil
}
}
}
func toString(buf []byte, valueType uint32, err error) (string, error) {
if err != nil {
return "", err
}
var value string
switch valueType {
case registry.SZ, registry.EXPAND_SZ, registry.MULTI_SZ:
if len(buf) == 0 {
return "", nil
}
value = windows.UTF16ToString((*[(1 << 30) - 1]uint16)(unsafe.Pointer(&buf[0]))[:len(buf)/2])
default:
return "", registry.ErrUnexpectedType
}
if valueType != registry.EXPAND_SZ {
// Value does not require expansion.
return value, nil
}
valueExp, err := registry.ExpandString(value)
if err != nil {
// Expanding failed: return original sting value.
return value, nil
}
// Return expanded value.
return valueExp, nil
}
func toInteger(buf []byte, valueType uint32, err error) (uint64, error) {
if err != nil {
return 0, err
}
switch valueType {
case registry.DWORD:
if len(buf) != 4 {
return 0, errors.New("DWORD value is not 4 bytes long")
}
var val uint32
copy((*[4]byte)(unsafe.Pointer(&val))[:], buf)
return uint64(val), nil
case registry.QWORD:
if len(buf) != 8 {
return 0, errors.New("QWORD value is not 8 bytes long")
}
var val uint64
copy((*[8]byte)(unsafe.Pointer(&val))[:], buf)
return val, nil
default:
return 0, registry.ErrUnexpectedType
}
}
//
// GetStringValueWait function reads a string value from registry. It waits
// for the registry value to become available or returns error on timeout.
//
// Key must be opened with at least QUERY_VALUE|NOTIFY access.
//
// If the value type is REG_EXPAND_SZ the environment variables are expanded.
// Should expanding fail, original string value and nil error are returned.
//
// If the value type is REG_MULTI_SZ only the first string is returned.
//
func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) {
return toString(getValueRetry(key, name, make([]byte, 256), timeout))
}
//
// GetStringValue function reads a string value from registry.
//
// Key must be opened with at least QUERY_VALUE access.
//
// If the value type is REG_EXPAND_SZ the environment variables are expanded.
// Should expanding fail, original string value and nil error are returned.
//
// If the value type is REG_MULTI_SZ only the first string is returned.
//
func GetStringValue(key registry.Key, name string) (string, error) {
return toString(getValue(key, name, make([]byte, 256)))
}
//
// GetIntegerValueWait function reads a DWORD32 or QWORD value from registry.
// It waits for the registry value to become available or returns error on
// timeout.
//
// Key must be opened with at least QUERY_VALUE|NOTIFY access.
//
func GetIntegerValueWait(key registry.Key, name string, timeout time.Duration) (uint64, error) {
return toInteger(getValueRetry(key, name, make([]byte, 8), timeout))
}

View File

@@ -0,0 +1,103 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package registry
import (
"testing"
"time"
"golang.org/x/sys/windows/registry"
)
const keyRoot = registry.CURRENT_USER
const pathRoot = "Software\\WireGuardRegistryTest"
const path = pathRoot + "\\foobar"
const pathFake = pathRoot + "\\raboof"
func Test_WaitForKey(t *testing.T) {
registry.DeleteKey(keyRoot, path)
registry.DeleteKey(keyRoot, pathRoot)
go func() {
time.Sleep(time.Second * 1)
key, _, err := registry.CreateKey(keyRoot, pathFake, registry.QUERY_VALUE)
if err != nil {
t.Errorf("Error creating registry key: %v", err)
}
key.Close()
registry.DeleteKey(keyRoot, pathFake)
key, _, err = registry.CreateKey(keyRoot, path, registry.QUERY_VALUE)
if err != nil {
t.Errorf("Error creating registry key: %v", err)
}
key.Close()
}()
err := WaitForKey(keyRoot, path, time.Second*2)
if err != nil {
t.Errorf("Error waiting for registry key: %v", err)
}
registry.DeleteKey(keyRoot, path)
registry.DeleteKey(keyRoot, pathRoot)
err = WaitForKey(keyRoot, path, time.Second*1)
if err == nil {
t.Error("Registry key notification expected to timeout but it succeeded.")
}
}
func Test_GetValueWait(t *testing.T) {
registry.DeleteKey(keyRoot, path)
registry.DeleteKey(keyRoot, pathRoot)
go func() {
time.Sleep(time.Second * 1)
key, _, err := registry.CreateKey(keyRoot, path, registry.SET_VALUE)
if err != nil {
t.Errorf("Error creating registry key: %v", err)
}
time.Sleep(time.Second * 1)
key.SetStringValue("name1", "eulav")
key.SetExpandStringValue("name2", "value")
time.Sleep(time.Second * 1)
key.SetDWordValue("name3", ^uint32(123))
key.SetDWordValue("name4", 123)
key.Close()
}()
key, err := OpenKeyWait(keyRoot, path, registry.QUERY_VALUE|registry.NOTIFY, time.Second*2)
if err != nil {
t.Errorf("Error waiting for registry key: %v", err)
}
valueStr, err := GetStringValueWait(key, "name2", time.Second*2)
if err != nil {
t.Errorf("Error waiting for registry value: %v", err)
}
if valueStr != "value" {
t.Errorf("Wrong value read: %v", valueStr)
}
_, err = GetStringValueWait(key, "nonexisting", time.Second*1)
if err == nil {
t.Error("Registry value notification expected to timeout but it succeeded.")
}
valueInt, err := GetIntegerValueWait(key, "name4", time.Second*2)
if err != nil {
t.Errorf("Error waiting for registry value: %v", err)
}
if valueInt != 123 {
t.Errorf("Wrong value read: %v", valueInt)
}
_, err = GetIntegerValueWait(key, "nonexisting", time.Second*1)
if err == nil {
t.Error("Registry value notification expected to timeout but it succeeded.")
}
key.Close()
registry.DeleteKey(keyRoot, path)
registry.DeleteKey(keyRoot, pathRoot)
}

View File

@@ -1,6 +1,6 @@
// Code generated by 'go generate'; DO NOT EDIT. // Code generated by 'go generate'; DO NOT EDIT.
package guid package registry
import ( import (
"syscall" "syscall"
@@ -37,13 +37,27 @@ func errnoErr(e syscall.Errno) error {
} }
var ( var (
modole32 = windows.NewLazySystemDLL("ole32.dll") modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
procCLSIDFromString = modole32.NewProc("CLSIDFromString") procRegNotifyChangeKeyValue = modadvapi32.NewProc("RegNotifyChangeKeyValue")
) )
func clsidFromString(lpsz *uint16, pclsid *windows.GUID) (hr int32) { func regNotifyChangeKeyValue(key windows.Handle, watchSubtree bool, notifyFilter uint32, event windows.Handle, asynchronous bool) (regerrno error) {
r0, _, _ := syscall.Syscall(procCLSIDFromString.Addr(), 2, uintptr(unsafe.Pointer(lpsz)), uintptr(unsafe.Pointer(pclsid)), 0) var _p0 uint32
hr = int32(r0) if watchSubtree {
_p0 = 1
} else {
_p0 = 0
}
var _p1 uint32
if asynchronous {
_p1 = 1
} else {
_p1 = 0
}
r0, _, _ := syscall.Syscall6(procRegNotifyChangeKeyValue.Addr(), 5, uintptr(key), uintptr(_p0), uintptr(notifyFilter), uintptr(event), uintptr(_p1), 0)
if r0 != 0 {
regerrno = syscall.Errno(r0)
}
return return
} }

View File

@@ -1,42 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"golang.org/x/sys/windows/registry"
"time"
)
const (
numRetries = 25
retryTimeout = 100 * time.Millisecond
)
func registryOpenKeyRetry(k registry.Key, path string, access uint32) (key registry.Key, err error) {
for i := 0; i < numRetries; i++ {
key, err = registry.OpenKey(k, path, access)
if err == nil {
break
}
if i != numRetries - 1 {
time.Sleep(retryTimeout)
}
}
return
}
func keyGetStringValueRetry(k registry.Key, name string) (val string, valtype uint32, err error) {
for i := 0; i < numRetries; i++ {
val, valtype, err = k.GetStringValue(name)
if err == nil {
break
}
if i != numRetries - 1 {
time.Sleep(retryTimeout)
}
}
return
}

View File

@@ -8,7 +8,7 @@ package setupapi
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"syscall" "runtime"
"unsafe" "unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@@ -21,7 +21,7 @@ import (
func SetupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName string) (deviceInfoSet DevInfo, err error) { func SetupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName string) (deviceInfoSet DevInfo, err error) {
var machineNameUTF16 *uint16 var machineNameUTF16 *uint16
if machineName != "" { if machineName != "" {
machineNameUTF16, err = syscall.UTF16PtrFromString(machineName) machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
if err != nil { if err != nil {
return return
} }
@@ -34,13 +34,13 @@ func SetupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr,
// SetupDiGetDeviceInfoListDetail function retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name. // SetupDiGetDeviceInfoListDetail function retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name.
func SetupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo) (deviceInfoSetDetailData *DevInfoListDetailData, err error) { func SetupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo) (deviceInfoSetDetailData *DevInfoListDetailData, err error) {
data := &DevInfoListDetailData{} data := &DevInfoListDetailData{}
data.size = uint32(unsafe.Sizeof(*data)) data.size = sizeofDevInfoListDetailData
return data, setupDiGetDeviceInfoListDetail(deviceInfoSet, data) return data, setupDiGetDeviceInfoListDetail(deviceInfoSet, data)
} }
// GetDeviceInfoListDetail method retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name. // DeviceInfoListDetail method retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name.
func (deviceInfoSet DevInfo) GetDeviceInfoListDetail() (*DevInfoListDetailData, error) { func (deviceInfoSet DevInfo) DeviceInfoListDetail() (*DevInfoListDetailData, error) {
return SetupDiGetDeviceInfoListDetail(deviceInfoSet) return SetupDiGetDeviceInfoListDetail(deviceInfoSet)
} }
@@ -48,14 +48,14 @@ func (deviceInfoSet DevInfo) GetDeviceInfoListDetail() (*DevInfoListDetailData,
// SetupDiCreateDeviceInfo function creates a new device information element and adds it as a new member to the specified device information set. // SetupDiCreateDeviceInfo function creates a new device information element and adds it as a new member to the specified device information set.
func SetupDiCreateDeviceInfo(deviceInfoSet DevInfo, deviceName string, classGUID *windows.GUID, deviceDescription string, hwndParent uintptr, creationFlags DICD) (deviceInfoData *DevInfoData, err error) { func SetupDiCreateDeviceInfo(deviceInfoSet DevInfo, deviceName string, classGUID *windows.GUID, deviceDescription string, hwndParent uintptr, creationFlags DICD) (deviceInfoData *DevInfoData, err error) {
deviceNameUTF16, err := syscall.UTF16PtrFromString(deviceName) deviceNameUTF16, err := windows.UTF16PtrFromString(deviceName)
if err != nil { if err != nil {
return return
} }
var deviceDescriptionUTF16 *uint16 var deviceDescriptionUTF16 *uint16
if deviceDescription != "" { if deviceDescription != "" {
deviceDescriptionUTF16, err = syscall.UTF16PtrFromString(deviceDescription) deviceDescriptionUTF16, err = windows.UTF16PtrFromString(deviceDescription)
if err != nil { if err != nil {
return return
} }
@@ -134,8 +134,8 @@ func SetupDiGetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData
return data, setupDiGetSelectedDriver(deviceInfoSet, deviceInfoData, data) return data, setupDiGetSelectedDriver(deviceInfoSet, deviceInfoData, data)
} }
// GetSelectedDriver method retrieves the selected driver for a device information set or a particular device information element. // SelectedDriver method retrieves the selected driver for a device information set or a particular device information element.
func (deviceInfoSet DevInfo) GetSelectedDriver(deviceInfoData *DevInfoData) (*DrvInfoData, error) { func (deviceInfoSet DevInfo) SelectedDriver(deviceInfoData *DevInfoData) (*DrvInfoData, error) {
return SetupDiGetSelectedDriver(deviceInfoSet, deviceInfoData) return SetupDiGetSelectedDriver(deviceInfoSet, deviceInfoData)
} }
@@ -150,38 +150,25 @@ func (deviceInfoSet DevInfo) SetSelectedDriver(deviceInfoData *DevInfoData, driv
// SetupDiGetDriverInfoDetail function retrieves driver information detail for a device information set or a particular device information element in the device information set. // SetupDiGetDriverInfoDetail function retrieves driver information detail for a device information set or a particular device information element in the device information set.
func SetupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) { func SetupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) {
const bufCapacity = 0x800 reqSize := uint32(2048)
buf := [bufCapacity]byte{} for {
var bufLen uint32 buf := make([]byte, reqSize)
data := (*DrvInfoDetailData)(unsafe.Pointer(&buf[0]))
data := (*DrvInfoDetailData)(unsafe.Pointer(&buf[0])) data.size = sizeofDrvInfoDetailData
data.size = uint32(unsafe.Sizeof(*data)) err := setupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData, data, uint32(len(buf)), &reqSize)
if err == windows.ERROR_INSUFFICIENT_BUFFER {
err := setupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData, data, bufCapacity, &bufLen) continue
if err == nil { }
// The buffer was was sufficiently big. if err != nil {
data.size = bufLen return nil, err
}
data.size = reqSize
return data, nil return data, nil
} }
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_INSUFFICIENT_BUFFER {
// The buffer was too small. Now that we got the required size, create another one big enough and retry.
buf := make([]byte, bufLen)
data := (*DrvInfoDetailData)(unsafe.Pointer(&buf[0]))
data.size = uint32(unsafe.Sizeof(*data))
err = setupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData, data, bufLen, &bufLen)
if err == nil {
data.size = bufLen
return data, nil
}
}
return nil, err
} }
// GetDriverInfoDetail method retrieves driver information detail for a device information set or a particular device information element in the device information set. // DriverInfoDetail method retrieves driver information detail for a device information set or a particular device information element in the device information set.
func (deviceInfoSet DevInfo) GetDriverInfoDetail(deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) { func (deviceInfoSet DevInfo) DriverInfoDetail(deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) {
return SetupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData) return SetupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData)
} }
@@ -198,14 +185,14 @@ func (deviceInfoSet DevInfo) DestroyDriverInfoList(deviceInfoData *DevInfoData,
func SetupDiGetClassDevsEx(classGUID *windows.GUID, enumerator string, hwndParent uintptr, flags DIGCF, deviceInfoSet DevInfo, machineName string) (handle DevInfo, err error) { func SetupDiGetClassDevsEx(classGUID *windows.GUID, enumerator string, hwndParent uintptr, flags DIGCF, deviceInfoSet DevInfo, machineName string) (handle DevInfo, err error) {
var enumeratorUTF16 *uint16 var enumeratorUTF16 *uint16
if enumerator != "" { if enumerator != "" {
enumeratorUTF16, err = syscall.UTF16PtrFromString(enumerator) enumeratorUTF16, err = windows.UTF16PtrFromString(enumerator)
if err != nil { if err != nil {
return return
} }
} }
var machineNameUTF16 *uint16 var machineNameUTF16 *uint16
if machineName != "" { if machineName != "" {
machineNameUTF16, err = syscall.UTF16PtrFromString(machineName) machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
if err != nil { if err != nil {
return return
} }
@@ -238,32 +225,31 @@ func (deviceInfoSet DevInfo) OpenDevRegKey(DeviceInfoData *DevInfoData, Scope DI
// SetupDiGetDeviceRegistryProperty function retrieves a specified Plug and Play device property. // SetupDiGetDeviceRegistryProperty function retrieves a specified Plug and Play device property.
func SetupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP) (value interface{}, err error) { func SetupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP) (value interface{}, err error) {
buf := make([]byte, 0x100) reqSize := uint32(256)
var dataType, bufLen uint32 for {
err = setupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, &dataType, &buf[0], uint32(cap(buf)), &bufLen) var dataType uint32
if err == nil { buf := make([]byte, reqSize)
// The buffer was sufficiently big. err = setupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, &dataType, &buf[0], uint32(len(buf)), &reqSize)
return getRegistryValue(buf[:bufLen], dataType) if err == windows.ERROR_INSUFFICIENT_BUFFER {
} continue
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_INSUFFICIENT_BUFFER {
// The buffer was too small. Now that we got the required size, create another one big enough and retry.
buf = make([]byte, bufLen)
err = setupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, &dataType, &buf[0], uint32(cap(buf)), &bufLen)
if err == nil {
return getRegistryValue(buf[:bufLen], dataType)
} }
if err != nil {
return
}
return getRegistryValue(buf[:reqSize], dataType)
} }
return
} }
func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) { func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) {
switch dataType { switch dataType {
case windows.REG_SZ: case windows.REG_SZ:
return windows.UTF16ToString(BufToUTF16(buf)), nil ret := windows.UTF16ToString(bufToUTF16(buf))
runtime.KeepAlive(buf)
return ret, nil
case windows.REG_EXPAND_SZ: case windows.REG_EXPAND_SZ:
return registry.ExpandString(windows.UTF16ToString(BufToUTF16(buf))) ret, err := registry.ExpandString(windows.UTF16ToString(bufToUTF16(buf)))
runtime.KeepAlive(buf)
return ret, err
case windows.REG_BINARY: case windows.REG_BINARY:
return buf, nil return buf, nil
case windows.REG_DWORD_LITTLE_ENDIAN: case windows.REG_DWORD_LITTLE_ENDIAN:
@@ -271,7 +257,7 @@ func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) {
case windows.REG_DWORD_BIG_ENDIAN: case windows.REG_DWORD_BIG_ENDIAN:
return binary.BigEndian.Uint32(buf), nil return binary.BigEndian.Uint32(buf), nil
case windows.REG_MULTI_SZ: case windows.REG_MULTI_SZ:
bufW := BufToUTF16(buf) bufW := bufToUTF16(buf)
a := []string{} a := []string{}
for i := 0; i < len(bufW); { for i := 0; i < len(bufW); {
j := i + wcslen(bufW[i:]) j := i + wcslen(bufW[i:])
@@ -280,6 +266,7 @@ func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) {
} }
i = j + 1 i = j + 1
} }
runtime.KeepAlive(buf)
return a, nil return a, nil
case windows.REG_QWORD_LITTLE_ENDIAN: case windows.REG_QWORD_LITTLE_ENDIAN:
return binary.LittleEndian.Uint64(buf), nil return binary.LittleEndian.Uint64(buf), nil
@@ -288,8 +275,8 @@ func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) {
} }
} }
// BufToUTF16 function reinterprets []byte buffer as []uint16 // bufToUTF16 function reinterprets []byte buffer as []uint16
func BufToUTF16(buf []byte) []uint16 { func bufToUTF16(buf []byte) []uint16 {
sl := struct { sl := struct {
addr *uint16 addr *uint16
len int len int
@@ -298,8 +285,8 @@ func BufToUTF16(buf []byte) []uint16 {
return *(*[]uint16)(unsafe.Pointer(&sl)) return *(*[]uint16)(unsafe.Pointer(&sl))
} }
// UTF16ToBuf function reinterprets []uint16 as []byte // utf16ToBuf function reinterprets []uint16 as []byte
func UTF16ToBuf(buf []uint16) []byte { func utf16ToBuf(buf []uint16) []byte {
sl := struct { sl := struct {
addr *byte addr *byte
len int len int
@@ -317,8 +304,8 @@ func wcslen(str []uint16) int {
return len(str) return len(str)
} }
// GetDeviceRegistryProperty method retrieves a specified Plug and Play device property. // DeviceRegistryProperty method retrieves a specified Plug and Play device property.
func (deviceInfoSet DevInfo) GetDeviceRegistryProperty(deviceInfoData *DevInfoData, property SPDRP) (interface{}, error) { func (deviceInfoSet DevInfo) DeviceRegistryProperty(deviceInfoData *DevInfoData, property SPDRP) (interface{}, error) {
return SetupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property) return SetupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property)
} }
@@ -334,6 +321,17 @@ func (deviceInfoSet DevInfo) SetDeviceRegistryProperty(deviceInfoData *DevInfoDa
return SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, propertyBuffers) return SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, propertyBuffers)
} }
// SetDeviceRegistryPropertyString method sets a Plug and Play device property string for a device.
func (deviceInfoSet DevInfo) SetDeviceRegistryPropertyString(deviceInfoData *DevInfoData, property SPDRP, str string) error {
str16, err := windows.UTF16FromString(str)
if err != nil {
return err
}
err = SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, utf16ToBuf(append(str16, 0)))
runtime.KeepAlive(str16)
return err
}
//sys setupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) = setupapi.SetupDiGetDeviceInstallParamsW //sys setupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) = setupapi.SetupDiGetDeviceInstallParamsW
// SetupDiGetDeviceInstallParams function retrieves device installation parameters for a device information set or a particular device information element. // SetupDiGetDeviceInstallParams function retrieves device installation parameters for a device information set or a particular device information element.
@@ -344,16 +342,39 @@ func SetupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInf
return params, setupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData, params) return params, setupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData, params)
} }
// GetDeviceInstallParams method retrieves device installation parameters for a device information set or a particular device information element. // DeviceInstallParams method retrieves device installation parameters for a device information set or a particular device information element.
func (deviceInfoSet DevInfo) GetDeviceInstallParams(deviceInfoData *DevInfoData) (*DevInstallParams, error) { func (deviceInfoSet DevInfo) DeviceInstallParams(deviceInfoData *DevInfoData) (*DevInstallParams, error) {
return SetupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData) return SetupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData)
} }
//sys setupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, instanceId *uint16, instanceIdSize uint32, instanceIdRequiredSize *uint32) (err error) = setupapi.SetupDiGetDeviceInstanceIdW
// SetupDiGetDeviceInstanceId function retrieves the instance ID of the device.
func SetupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (string, error) {
reqSize := uint32(1024)
for {
buf := make([]uint16, reqSize)
err := setupDiGetDeviceInstanceId(deviceInfoSet, deviceInfoData, &buf[0], uint32(len(buf)), &reqSize)
if err == windows.ERROR_INSUFFICIENT_BUFFER {
continue
}
if err != nil {
return "", err
}
return windows.UTF16ToString(buf), nil
}
}
// DeviceInstanceID method retrieves the instance ID of the device.
func (deviceInfoSet DevInfo) DeviceInstanceID(deviceInfoData *DevInfoData) (string, error) {
return SetupDiGetDeviceInstanceId(deviceInfoSet, deviceInfoData)
}
// SetupDiGetClassInstallParams function retrieves class installation parameters for a device information set or a particular device information element. // SetupDiGetClassInstallParams function retrieves class installation parameters for a device information set or a particular device information element.
//sys SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetClassInstallParamsW //sys SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetClassInstallParamsW
// GetClassInstallParams method retrieves class installation parameters for a device information set or a particular device information element. // ClassInstallParams method retrieves class installation parameters for a device information set or a particular device information element.
func (deviceInfoSet DevInfo) GetClassInstallParams(deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) error { func (deviceInfoSet DevInfo) ClassInstallParams(deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) error {
return SetupDiGetClassInstallParams(deviceInfoSet, deviceInfoData, classInstallParams, classInstallParamsSize, requiredSize) return SetupDiGetClassInstallParams(deviceInfoSet, deviceInfoData, classInstallParams, classInstallParamsSize, requiredSize)
} }
@@ -380,7 +401,7 @@ func SetupDiClassNameFromGuidEx(classGUID *windows.GUID, machineName string) (cl
var machineNameUTF16 *uint16 var machineNameUTF16 *uint16
if machineName != "" { if machineName != "" {
machineNameUTF16, err = syscall.UTF16PtrFromString(machineName) machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
if err != nil { if err != nil {
return return
} }
@@ -399,39 +420,31 @@ func SetupDiClassNameFromGuidEx(classGUID *windows.GUID, machineName string) (cl
// SetupDiClassGuidsFromNameEx function retrieves the GUIDs associated with the specified class name. This resulting list contains the classes currently installed on a local or remote computer. // SetupDiClassGuidsFromNameEx function retrieves the GUIDs associated with the specified class name. This resulting list contains the classes currently installed on a local or remote computer.
func SetupDiClassGuidsFromNameEx(className string, machineName string) ([]windows.GUID, error) { func SetupDiClassGuidsFromNameEx(className string, machineName string) ([]windows.GUID, error) {
classNameUTF16, err := syscall.UTF16PtrFromString(className) classNameUTF16, err := windows.UTF16PtrFromString(className)
if err != nil { if err != nil {
return nil, err return nil, err
} }
const bufCapacity = 4
var buf [bufCapacity]windows.GUID
var bufLen uint32
var machineNameUTF16 *uint16 var machineNameUTF16 *uint16
if machineName != "" { if machineName != "" {
machineNameUTF16, err = syscall.UTF16PtrFromString(machineName) machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
if err != nil { if err != nil {
return nil, err return nil, err
} }
} }
err = setupDiClassGuidsFromNameEx(classNameUTF16, &buf[0], bufCapacity, &bufLen, machineNameUTF16, 0) reqSize := uint32(4)
if err == nil { for {
// The GUID array was sufficiently big. Return its slice. buf := make([]windows.GUID, reqSize)
return buf[:bufLen], nil err = setupDiClassGuidsFromNameEx(classNameUTF16, &buf[0], uint32(len(buf)), &reqSize, machineNameUTF16, 0)
} if err == windows.ERROR_INSUFFICIENT_BUFFER {
continue
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_INSUFFICIENT_BUFFER {
// The GUID array was too small. Now that we got the required size, create another one big enough and retry.
buf := make([]windows.GUID, bufLen)
err = setupDiClassGuidsFromNameEx(classNameUTF16, &buf[0], bufLen, &bufLen, machineNameUTF16, 0)
if err == nil {
return buf[:bufLen], nil
} }
if err != nil {
return nil, err
}
return buf[:reqSize], nil
} }
return nil, err
} }
//sys setupDiGetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiGetSelectedDevice //sys setupDiGetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiGetSelectedDevice
@@ -444,8 +457,8 @@ func SetupDiGetSelectedDevice(deviceInfoSet DevInfo) (*DevInfoData, error) {
return data, setupDiGetSelectedDevice(deviceInfoSet, data) return data, setupDiGetSelectedDevice(deviceInfoSet, data)
} }
// GetSelectedDevice method retrieves the selected device information element in a device information set. // SelectedDevice method retrieves the selected device information element in a device information set.
func (deviceInfoSet DevInfo) GetSelectedDevice() (*DevInfoData, error) { func (deviceInfoSet DevInfo) SelectedDevice() (*DevInfoData, error) {
return SetupDiGetSelectedDevice(deviceInfoSet) return SetupDiGetSelectedDevice(deviceInfoSet)
} }
@@ -456,3 +469,38 @@ func (deviceInfoSet DevInfo) GetSelectedDevice() (*DevInfoData, error) {
func (deviceInfoSet DevInfo) SetSelectedDevice(deviceInfoData *DevInfoData) error { func (deviceInfoSet DevInfo) SetSelectedDevice(deviceInfoData *DevInfoData) error {
return SetupDiSetSelectedDevice(deviceInfoSet, deviceInfoData) return SetupDiSetSelectedDevice(deviceInfoSet, deviceInfoData)
} }
//sys cm_Get_Device_Interface_List_Size(len *uint32, interfaceClass *windows.GUID, deviceID *uint16, flags uint32) (ret uint32) = CfgMgr32.CM_Get_Device_Interface_List_SizeW
//sys cm_Get_Device_Interface_List(interfaceClass *windows.GUID, deviceID *uint16, buffer *uint16, bufferLen uint32, flags uint32) (ret uint32) = CfgMgr32.CM_Get_Device_Interface_ListW
func CM_Get_Device_Interface_List(deviceID string, interfaceClass *windows.GUID, flags uint32) ([]string, error) {
deviceID16, err := windows.UTF16PtrFromString(deviceID)
if err != nil {
return nil, err
}
var buf []uint16
var buflen uint32
for {
if ret := cm_Get_Device_Interface_List_Size(&buflen, interfaceClass, deviceID16, flags); ret != CR_SUCCESS {
return nil, fmt.Errorf("CfgMgr error: 0x%x", ret)
}
buf = make([]uint16, buflen)
if ret := cm_Get_Device_Interface_List(interfaceClass, deviceID16, &buf[0], buflen, flags); ret == CR_SUCCESS {
break
} else if ret != CR_BUFFER_SMALL {
return nil, fmt.Errorf("CfgMgr error: 0x%x", ret)
}
}
var interfaces []string
for i := 0; i < len(buf); {
j := i + wcslen(buf[i:])
if i < j {
interfaces = append(interfaces, windows.UTF16ToString(buf[i:j]))
}
i = j + 1
}
if interfaces == nil {
return nil, fmt.Errorf("no interfaces found")
}
return interfaces, nil
}

View File

@@ -6,12 +6,11 @@
package setupapi package setupapi
import ( import (
"runtime"
"strings" "strings"
"syscall"
"testing" "testing"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/tun/wintun/guid"
) )
var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}} var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}}
@@ -23,24 +22,24 @@ func init() {
func TestSetupDiCreateDeviceInfoListEx(t *testing.T) { func TestSetupDiCreateDeviceInfoListEx(t *testing.T) {
devInfoList, err := SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "") devInfoList, err := SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
if err == nil { if err != nil {
devInfoList.Close()
} else {
t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error()) t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error())
} else {
devInfoList.Close()
} }
devInfoList, err = SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, computerName) devInfoList, err = SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, computerName)
if err == nil { if err != nil {
devInfoList.Close()
} else {
t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error()) t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error())
} else {
devInfoList.Close()
} }
devInfoList, err = SetupDiCreateDeviceInfoListEx(nil, 0, "") devInfoList, err = SetupDiCreateDeviceInfoListEx(nil, 0, "")
if err == nil { if err != nil {
devInfoList.Close()
} else {
t.Errorf("Error calling SetupDiCreateDeviceInfoListEx(nil): %s", err.Error()) t.Errorf("Error calling SetupDiCreateDeviceInfoListEx(nil): %s", err.Error())
} else {
devInfoList.Close()
} }
} }
@@ -51,7 +50,7 @@ func TestSetupDiGetDeviceInfoListDetail(t *testing.T) {
} }
defer devInfoList.Close() defer devInfoList.Close()
data, err := devInfoList.GetDeviceInfoListDetail() data, err := devInfoList.DeviceInfoListDetail()
if err != nil { if err != nil {
t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error()) t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error())
} else { } else {
@@ -63,7 +62,7 @@ func TestSetupDiGetDeviceInfoListDetail(t *testing.T) {
t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine handle") t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine handle")
} }
if data.GetRemoteMachineName() != "" { if data.RemoteMachineName() != "" {
t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine name") t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine name")
} }
} }
@@ -74,7 +73,7 @@ func TestSetupDiGetDeviceInfoListDetail(t *testing.T) {
} }
defer devInfoList.Close() defer devInfoList.Close()
data, err = devInfoList.GetDeviceInfoListDetail() data, err = devInfoList.DeviceInfoListDetail()
if err != nil { if err != nil {
t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error()) t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error())
} else { } else {
@@ -86,14 +85,14 @@ func TestSetupDiGetDeviceInfoListDetail(t *testing.T) {
t.Error("SetupDiGetDeviceInfoListDetail returned NULL remote machine handle") t.Error("SetupDiGetDeviceInfoListDetail returned NULL remote machine handle")
} }
if data.GetRemoteMachineName() != computerName { if data.RemoteMachineName() != computerName {
t.Error("SetupDiGetDeviceInfoListDetail returned different remote machine name") t.Error("SetupDiGetDeviceInfoListDetail returned different remote machine name")
} }
} }
data = &DevInfoListDetailData{} data = &DevInfoListDetailData{}
data.SetRemoteMachineName("foobar") data.SetRemoteMachineName("foobar")
if data.GetRemoteMachineName() != "foobar" { if data.RemoteMachineName() != "foobar" {
t.Error("DevInfoListDetailData.(Get|Set)RemoteMachineName() differ") t.Error("DevInfoListDetailData.(Get|Set)RemoteMachineName() differ")
} }
} }
@@ -113,7 +112,7 @@ func TestSetupDiCreateDeviceInfo(t *testing.T) {
devInfoData, err := devInfoList.CreateDeviceInfo(deviceClassNetName, &deviceClassNetGUID, "This is a test device", 0, DICD_GENERATE_ID) devInfoData, err := devInfoList.CreateDeviceInfo(deviceClassNetName, &deviceClassNetGUID, "This is a test device", 0, DICD_GENERATE_ID)
if err != nil { if err != nil {
// Access denied is expected, as the SetupDiCreateDeviceInfo() require elevation to succeed. // Access denied is expected, as the SetupDiCreateDeviceInfo() require elevation to succeed.
if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_ACCESS_DENIED { if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_ACCESS_DENIED {
t.Errorf("Error calling SetupDiCreateDeviceInfo: %s", err.Error()) t.Errorf("Error calling SetupDiCreateDeviceInfo: %s", err.Error())
} }
} else if devInfoData.ClassGUID != deviceClassNetGUID { } else if devInfoData.ClassGUID != deviceClassNetGUID {
@@ -131,7 +130,7 @@ func TestSetupDiEnumDeviceInfo(t *testing.T) {
for i := 0; true; i++ { for i := 0; true; i++ {
data, err := devInfoList.EnumDeviceInfo(i) data, err := devInfoList.EnumDeviceInfo(i)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
@@ -140,6 +139,11 @@ func TestSetupDiEnumDeviceInfo(t *testing.T) {
if data.ClassGUID != deviceClassNetGUID { if data.ClassGUID != deviceClassNetGUID {
t.Error("SetupDiEnumDeviceInfo returned different class GUID") t.Error("SetupDiEnumDeviceInfo returned different class GUID")
} }
_, err = devInfoList.DeviceInstanceID(data)
if err != nil {
t.Errorf("Error calling SetupDiGetDeviceInstanceId: %s", err.Error())
}
} }
} }
@@ -153,7 +157,7 @@ func TestDevInfo_BuildDriverInfoList(t *testing.T) {
for i := 0; true; i++ { for i := 0; true; i++ {
deviceData, err := devInfoList.EnumDeviceInfo(i) deviceData, err := devInfoList.EnumDeviceInfo(i)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
@@ -170,7 +174,7 @@ func TestDevInfo_BuildDriverInfoList(t *testing.T) {
for j := 0; true; j++ { for j := 0; true; j++ {
driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, j) driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, j)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
@@ -209,7 +213,7 @@ func TestDevInfo_BuildDriverInfoList(t *testing.T) {
selectedDriverData = driverData selectedDriverData = driverData
} }
driverDetailData, err := devInfoList.GetDriverInfoDetail(deviceData, driverData) driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
if err != nil { if err != nil {
t.Errorf("Error calling SetupDiGetDriverInfoDetail: %s", err.Error()) t.Errorf("Error calling SetupDiGetDriverInfoDetail: %s", err.Error())
} }
@@ -217,10 +221,10 @@ func TestDevInfo_BuildDriverInfoList(t *testing.T) {
if driverDetailData.IsCompatible("foobar-aab6e3a4-144e-4786-88d3-6cec361e1edd") { if driverDetailData.IsCompatible("foobar-aab6e3a4-144e-4786-88d3-6cec361e1edd") {
t.Error("Invalid HWID compatibitlity reported") t.Error("Invalid HWID compatibitlity reported")
} }
if !driverDetailData.IsCompatible(strings.ToUpper(driverDetailData.GetHardwareID())) { if !driverDetailData.IsCompatible(strings.ToUpper(driverDetailData.HardwareID())) {
t.Error("HWID compatibitlity missed") t.Error("HWID compatibitlity missed")
} }
a := driverDetailData.GetCompatIDs() a := driverDetailData.CompatIDs()
for k := range a { for k := range a {
if !driverDetailData.IsCompatible(strings.ToUpper(a[k])) { if !driverDetailData.IsCompatible(strings.ToUpper(a[k])) {
t.Error("HWID compatibitlity missed") t.Error("HWID compatibitlity missed")
@@ -228,7 +232,7 @@ func TestDevInfo_BuildDriverInfoList(t *testing.T) {
} }
} }
selectedDriverData2, err := devInfoList.GetSelectedDriver(deviceData) selectedDriverData2, err := devInfoList.SelectedDriver(deviceData)
if err != nil { if err != nil {
t.Errorf("Error calling SetupDiGetSelectedDriver: %s", err.Error()) t.Errorf("Error calling SetupDiGetSelectedDriver: %s", err.Error())
} else if *selectedDriverData != *selectedDriverData2 { } else if *selectedDriverData != *selectedDriverData2 {
@@ -238,35 +242,35 @@ func TestDevInfo_BuildDriverInfoList(t *testing.T) {
data := &DrvInfoData{} data := &DrvInfoData{}
data.SetDescription("foobar") data.SetDescription("foobar")
if data.GetDescription() != "foobar" { if data.Description() != "foobar" {
t.Error("DrvInfoData.(Get|Set)Description() differ") t.Error("DrvInfoData.(Get|Set)Description() differ")
} }
data.SetMfgName("foobar") data.SetMfgName("foobar")
if data.GetMfgName() != "foobar" { if data.MfgName() != "foobar" {
t.Error("DrvInfoData.(Get|Set)MfgName() differ") t.Error("DrvInfoData.(Get|Set)MfgName() differ")
} }
data.SetProviderName("foobar") data.SetProviderName("foobar")
if data.GetProviderName() != "foobar" { if data.ProviderName() != "foobar" {
t.Error("DrvInfoData.(Get|Set)ProviderName() differ") t.Error("DrvInfoData.(Get|Set)ProviderName() differ")
} }
} }
func TestSetupDiGetClassDevsEx(t *testing.T) { func TestSetupDiGetClassDevsEx(t *testing.T) {
devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "PCI", 0, DIGCF_PRESENT, DevInfo(0), computerName) devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "PCI", 0, DIGCF_PRESENT, DevInfo(0), computerName)
if err == nil { if err != nil {
devInfoList.Close()
} else {
t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error()) t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
} else {
devInfoList.Close()
} }
devInfoList, err = SetupDiGetClassDevsEx(nil, "", 0, DIGCF_PRESENT, DevInfo(0), "") devInfoList, err = SetupDiGetClassDevsEx(nil, "", 0, DIGCF_PRESENT, DevInfo(0), "")
if err == nil { if err != nil {
devInfoList.Close() if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER {
t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail")
} else {
if errWin, ok := err.(syscall.Errno); !ok || errWin != 87 /*ERROR_INVALID_PARAMETER*/ {
t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail with ERROR_INVALID_PARAMETER") t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail with ERROR_INVALID_PARAMETER")
} }
} else {
devInfoList.Close()
t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail")
} }
} }
@@ -280,7 +284,7 @@ func TestSetupDiOpenDevRegKey(t *testing.T) {
for i := 0; true; i++ { for i := 0; true; i++ {
data, err := devInfoList.EnumDeviceInfo(i) data, err := devInfoList.EnumDeviceInfo(i)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
@@ -304,47 +308,47 @@ func TestSetupDiGetDeviceRegistryProperty(t *testing.T) {
for i := 0; true; i++ { for i := 0; true; i++ {
data, err := devInfoList.EnumDeviceInfo(i) data, err := devInfoList.EnumDeviceInfo(i)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
} }
val, err := devInfoList.GetDeviceRegistryProperty(data, SPDRP_CLASS) val, err := devInfoList.DeviceRegistryProperty(data, SPDRP_CLASS)
if err != nil { if err != nil {
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASS): %s", err.Error()) t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASS): %s", err.Error())
} else if class, ok := val.(string); !ok || strings.ToLower(class) != "net" { } else if class, ok := val.(string); !ok || strings.ToLower(class) != "net" {
t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASS) should return \"Net\"") t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASS) should return \"Net\"")
} }
val, err = devInfoList.GetDeviceRegistryProperty(data, SPDRP_CLASSGUID) val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_CLASSGUID)
if err != nil { if err != nil {
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error()) t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error())
} else if valStr, ok := val.(string); !ok { } else if valStr, ok := val.(string); !ok {
t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return string") t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return string")
} else { } else {
classGUID, err := guid.FromString(valStr) classGUID, err := windows.GUIDFromString(valStr)
if err != nil { if err != nil {
t.Errorf("Error parsing GUID returned by SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error()) t.Errorf("Error parsing GUID returned by SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error())
} else if *classGUID != deviceClassNetGUID { } else if classGUID != deviceClassNetGUID {
t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return %x", deviceClassNetGUID) t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return %x", deviceClassNetGUID)
} }
} }
val, err = devInfoList.GetDeviceRegistryProperty(data, SPDRP_COMPATIBLEIDS) val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_COMPATIBLEIDS)
if err != nil { if err != nil {
// Some devices have no SPDRP_COMPATIBLEIDS. // Some devices have no SPDRP_COMPATIBLEIDS.
if errWin, ok := err.(syscall.Errno); !ok || errWin != 13 /*windows.ERROR_INVALID_DATA*/ { if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_DATA {
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_COMPATIBLEIDS): %s", err.Error()) t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_COMPATIBLEIDS): %s", err.Error())
} }
} }
val, err = devInfoList.GetDeviceRegistryProperty(data, SPDRP_CONFIGFLAGS) val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_CONFIGFLAGS)
if err != nil { if err != nil {
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CONFIGFLAGS): %s", err.Error()) t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CONFIGFLAGS): %s", err.Error())
} }
val, err = devInfoList.GetDeviceRegistryProperty(data, SPDRP_DEVICE_POWER_DATA) val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_DEVICE_POWER_DATA)
if err != nil { if err != nil {
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_DEVICE_POWER_DATA): %s", err.Error()) t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_DEVICE_POWER_DATA): %s", err.Error())
} }
@@ -361,13 +365,13 @@ func TestSetupDiGetDeviceInstallParams(t *testing.T) {
for i := 0; true; i++ { for i := 0; true; i++ {
data, err := devInfoList.EnumDeviceInfo(i) data, err := devInfoList.EnumDeviceInfo(i)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
} }
_, err = devInfoList.GetDeviceInstallParams(data) _, err = devInfoList.DeviceInstallParams(data)
if err != nil { if err != nil {
t.Errorf("Error calling SetupDiGetDeviceInstallParams: %s", err.Error()) t.Errorf("Error calling SetupDiGetDeviceInstallParams: %s", err.Error())
} }
@@ -375,7 +379,7 @@ func TestSetupDiGetDeviceInstallParams(t *testing.T) {
params := &DevInstallParams{} params := &DevInstallParams{}
params.SetDriverPath("foobar") params.SetDriverPath("foobar")
if params.GetDriverPath() != "foobar" { if params.DriverPath() != "foobar" {
t.Error("DevInstallParams.(Get|Set)DriverPath() differ") t.Error("DevInstallParams.(Get|Set)DriverPath() differ")
} }
} }
@@ -396,12 +400,12 @@ func TestSetupDiClassNameFromGuidEx(t *testing.T) {
} }
_, err = SetupDiClassNameFromGuidEx(nil, "") _, err = SetupDiClassNameFromGuidEx(nil, "")
if err == nil { if err != nil {
t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail") if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_USER_BUFFER {
} else {
if errWin, ok := err.(syscall.Errno); !ok || errWin != 1784 /*ERROR_INVALID_USER_BUFFER*/ {
t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail with ERROR_INVALID_USER_BUFFER") t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail with ERROR_INVALID_USER_BUFFER")
} }
} else {
t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail")
} }
} }
@@ -440,7 +444,7 @@ func TestSetupDiGetSelectedDevice(t *testing.T) {
for i := 0; true; i++ { for i := 0; true; i++ {
data, err := devInfoList.EnumDeviceInfo(i) data, err := devInfoList.EnumDeviceInfo(i)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
@@ -451,7 +455,7 @@ func TestSetupDiGetSelectedDevice(t *testing.T) {
t.Errorf("Error calling SetupDiSetSelectedDevice: %s", err.Error()) t.Errorf("Error calling SetupDiSetSelectedDevice: %s", err.Error())
} }
data2, err := devInfoList.GetSelectedDevice() data2, err := devInfoList.SelectedDevice()
if err != nil { if err != nil {
t.Errorf("Error calling SetupDiGetSelectedDevice: %s", err.Error()) t.Errorf("Error calling SetupDiGetSelectedDevice: %s", err.Error())
} else if *data != *data2 { } else if *data != *data2 {
@@ -460,18 +464,18 @@ func TestSetupDiGetSelectedDevice(t *testing.T) {
} }
err = devInfoList.SetSelectedDevice(nil) err = devInfoList.SetSelectedDevice(nil)
if err == nil { if err != nil {
t.Errorf("SetupDiSetSelectedDevice(nil) should fail") if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER {
} else {
if errWin, ok := err.(syscall.Errno); !ok || errWin != 87 /*ERROR_INVALID_PARAMETER*/ {
t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER") t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER")
} }
} else {
t.Errorf("SetupDiSetSelectedDevice(nil) should fail")
} }
} }
func TestUTF16ToBuf(t *testing.T) { func TestUTF16ToBuf(t *testing.T) {
buf := []uint16{0x0123, 0x4567, 0x89ab, 0xcdef} buf := []uint16{0x0123, 0x4567, 0x89ab, 0xcdef}
buf2 := UTF16ToBuf(buf) buf2 := utf16ToBuf(buf)
if len(buf)*2 != len(buf2) || if len(buf)*2 != len(buf2) ||
cap(buf)*2 != cap(buf2) || cap(buf)*2 != cap(buf2) ||
buf2[0] != 0x23 || buf2[1] != 0x01 || buf2[0] != 0x23 || buf2[1] != 0x01 ||
@@ -480,4 +484,5 @@ func TestUTF16ToBuf(t *testing.T) {
buf2[6] != 0xef || buf2[7] != 0xcd { buf2[6] != 0xef || buf2[7] != 0xcd {
t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER") t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER")
} }
runtime.KeepAlive(buf)
} }

View File

@@ -7,7 +7,6 @@ package setupapi
import ( import (
"strings" "strings"
"syscall"
"unsafe" "unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@@ -28,6 +27,7 @@ const (
// Define maximum string length constants // Define maximum string length constants
// //
const ( const (
ANYSIZE_ARRAY = 1
LINE_LEN = 256 // Windows 9x-compatible maximum for displayable strings coming from a device INF. LINE_LEN = 256 // Windows 9x-compatible maximum for displayable strings coming from a device INF.
MAX_INF_STRING_LENGTH = 4096 // Actual maximum size of an INF string (including string substitutions). MAX_INF_STRING_LENGTH = 4096 // Actual maximum size of an INF string (including string substitutions).
MAX_INF_SECTION_NAME_LENGTH = 255 // For Windows 9x compatibility, INF section names should be constrained to 32 characters. MAX_INF_SECTION_NAME_LENGTH = 255 // For Windows 9x compatibility, INF section names should be constrained to 32 characters.
@@ -59,18 +59,18 @@ type DevInfoData struct {
// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass). // DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass).
type DevInfoListDetailData struct { type DevInfoListDetailData struct {
size uint32 size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const.
ClassGUID windows.GUID ClassGUID windows.GUID
RemoteMachineHandle windows.Handle RemoteMachineHandle windows.Handle
remoteMachineName [SP_MAX_MACHINENAME_LENGTH]uint16 remoteMachineName [SP_MAX_MACHINENAME_LENGTH]uint16
} }
func (data *DevInfoListDetailData) GetRemoteMachineName() string { func (data *DevInfoListDetailData) RemoteMachineName() string {
return windows.UTF16ToString(data.remoteMachineName[:]) return windows.UTF16ToString(data.remoteMachineName[:])
} }
func (data *DevInfoListDetailData) SetRemoteMachineName(remoteMachineName string) error { func (data *DevInfoListDetailData) SetRemoteMachineName(remoteMachineName string) error {
str, err := syscall.UTF16FromString(remoteMachineName) str, err := windows.UTF16FromString(remoteMachineName)
if err != nil { if err != nil {
return err return err
} }
@@ -137,12 +137,12 @@ type DevInstallParams struct {
driverPath [windows.MAX_PATH]uint16 driverPath [windows.MAX_PATH]uint16
} }
func (params *DevInstallParams) GetDriverPath() string { func (params *DevInstallParams) DriverPath() string {
return windows.UTF16ToString(params.driverPath[:]) return windows.UTF16ToString(params.driverPath[:])
} }
func (params *DevInstallParams) SetDriverPath(driverPath string) error { func (params *DevInstallParams) SetDriverPath(driverPath string) error {
str, err := syscall.UTF16FromString(driverPath) str, err := windows.UTF16FromString(driverPath)
if err != nil { if err != nil {
return err return err
} }
@@ -267,15 +267,34 @@ func MakeClassInstallHeader(installFunction DI_FUNCTION) *ClassInstallHeader {
return hdr return hdr
} }
// DICS_STATE specifies values indicating a change in a device's state
type DICS_STATE uint32
const (
DICS_ENABLE DICS_STATE = 0x00000001 // The device is being enabled.
DICS_DISABLE DICS_STATE = 0x00000002 // The device is being disabled.
DICS_PROPCHANGE DICS_STATE = 0x00000003 // The properties of the device have changed.
DICS_START DICS_STATE = 0x00000004 // The device is being started (if the request is for the currently active hardware profile).
DICS_STOP DICS_STATE = 0x00000005 // The device is being stopped. The driver stack will be unloaded and the CSCONFIGFLAG_DO_NOT_START flag will be set for the device.
)
// DICS_FLAG specifies the scope of a device property change // DICS_FLAG specifies the scope of a device property change
type DICS_FLAG uint32 type DICS_FLAG uint32
const ( const (
DICS_FLAG_GLOBAL DICS_FLAG = 0x00000001 // make change in all hardware profiles DICS_FLAG_GLOBAL DICS_FLAG = 0x00000001 // make change in all hardware profiles
DICS_FLAG_CONFIGSPECIFIC DICS_FLAG = 0x00000002 // make change in specified profile only DICS_FLAG_CONFIGSPECIFIC DICS_FLAG = 0x00000002 // make change in specified profile only
DICS_FLAG_CONFIGGENERAL DICS_FLAG = 0x00000004 // 1 or more hardware profile-specific changes to follow DICS_FLAG_CONFIGGENERAL DICS_FLAG = 0x00000004 // 1 or more hardware profile-specific changes to follow (obsolete)
) )
// PropChangeParams is a structure corresponding to a DIF_PROPERTYCHANGE install function.
type PropChangeParams struct {
ClassInstallHeader ClassInstallHeader
StateChange DICS_STATE
Scope DICS_FLAG
HwProfile uint32
}
// DI_REMOVEDEVICE specifies the scope of the device removal // DI_REMOVEDEVICE specifies the scope of the device removal
type DI_REMOVEDEVICE uint32 type DI_REMOVEDEVICE uint32
@@ -303,12 +322,12 @@ type DrvInfoData struct {
DriverVersion uint64 DriverVersion uint64
} }
func (data *DrvInfoData) GetDescription() string { func (data *DrvInfoData) Description() string {
return windows.UTF16ToString(data.description[:]) return windows.UTF16ToString(data.description[:])
} }
func (data *DrvInfoData) SetDescription(description string) error { func (data *DrvInfoData) SetDescription(description string) error {
str, err := syscall.UTF16FromString(description) str, err := windows.UTF16FromString(description)
if err != nil { if err != nil {
return err return err
} }
@@ -316,12 +335,12 @@ func (data *DrvInfoData) SetDescription(description string) error {
return nil return nil
} }
func (data *DrvInfoData) GetMfgName() string { func (data *DrvInfoData) MfgName() string {
return windows.UTF16ToString(data.mfgName[:]) return windows.UTF16ToString(data.mfgName[:])
} }
func (data *DrvInfoData) SetMfgName(mfgName string) error { func (data *DrvInfoData) SetMfgName(mfgName string) error {
str, err := syscall.UTF16FromString(mfgName) str, err := windows.UTF16FromString(mfgName)
if err != nil { if err != nil {
return err return err
} }
@@ -329,12 +348,12 @@ func (data *DrvInfoData) SetMfgName(mfgName string) error {
return nil return nil
} }
func (data *DrvInfoData) GetProviderName() string { func (data *DrvInfoData) ProviderName() string {
return windows.UTF16ToString(data.providerName[:]) return windows.UTF16ToString(data.providerName[:])
} }
func (data *DrvInfoData) SetProviderName(providerName string) error { func (data *DrvInfoData) SetProviderName(providerName string) error {
str, err := syscall.UTF16FromString(providerName) str, err := windows.UTF16FromString(providerName)
if err != nil { if err != nil {
return err return err
} }
@@ -370,7 +389,7 @@ func (data *DrvInfoData) IsNewer(driverDate windows.Filetime, driverVersion uint
// DrvInfoDetailData is driver information details structure (provides detailed information about a particular driver information structure) // DrvInfoDetailData is driver information details structure (provides detailed information about a particular driver information structure)
type DrvInfoDetailData struct { type DrvInfoDetailData struct {
size uint32 // On input, this must be exactly the sizeof(DrvInfoDetailData). On output, we set this member to the actual size of structure data. size uint32 // Warning: unsafe.Sizeof(DrvInfoDetailData) > sizeof(SP_DRVINFO_DETAIL_DATA) when GOARCH == 386 => use sizeofDrvInfoDetailData const.
InfDate windows.Filetime InfDate windows.Filetime
compatIDsOffset uint32 compatIDsOffset uint32
compatIDsLength uint32 compatIDsLength uint32
@@ -378,22 +397,22 @@ type DrvInfoDetailData struct {
sectionName [LINE_LEN]uint16 sectionName [LINE_LEN]uint16
infFileName [windows.MAX_PATH]uint16 infFileName [windows.MAX_PATH]uint16
drvDescription [LINE_LEN]uint16 drvDescription [LINE_LEN]uint16
hardwareID [1]uint16 hardwareID [ANYSIZE_ARRAY]uint16
} }
func (data *DrvInfoDetailData) GetSectionName() string { func (data *DrvInfoDetailData) SectionName() string {
return windows.UTF16ToString(data.sectionName[:]) return windows.UTF16ToString(data.sectionName[:])
} }
func (data *DrvInfoDetailData) GetInfFileName() string { func (data *DrvInfoDetailData) InfFileName() string {
return windows.UTF16ToString(data.infFileName[:]) return windows.UTF16ToString(data.infFileName[:])
} }
func (data *DrvInfoDetailData) GetDrvDescription() string { func (data *DrvInfoDetailData) DrvDescription() string {
return windows.UTF16ToString(data.drvDescription[:]) return windows.UTF16ToString(data.drvDescription[:])
} }
func (data *DrvInfoDetailData) GetHardwareID() string { func (data *DrvInfoDetailData) HardwareID() string {
if data.compatIDsOffset > 1 { if data.compatIDsOffset > 1 {
bufW := data.getBuf() bufW := data.getBuf()
return windows.UTF16ToString(bufW[:wcslen(bufW)]) return windows.UTF16ToString(bufW[:wcslen(bufW)])
@@ -402,7 +421,7 @@ func (data *DrvInfoDetailData) GetHardwareID() string {
return "" return ""
} }
func (data *DrvInfoDetailData) GetCompatIDs() []string { func (data *DrvInfoDetailData) CompatIDs() []string {
a := make([]string, 0) a := make([]string, 0)
if data.compatIDsLength > 0 { if data.compatIDsLength > 0 {
@@ -433,10 +452,10 @@ func (data *DrvInfoDetailData) getBuf() []uint16 {
// IsCompatible method tests if given hardware ID matches the driver or is listed on the compatible ID list. // IsCompatible method tests if given hardware ID matches the driver or is listed on the compatible ID list.
func (data *DrvInfoDetailData) IsCompatible(hwid string) bool { func (data *DrvInfoDetailData) IsCompatible(hwid string) bool {
hwidLC := strings.ToLower(hwid) hwidLC := strings.ToLower(hwid)
if strings.ToLower(data.GetHardwareID()) == hwidLC { if strings.ToLower(data.HardwareID()) == hwidLC {
return true return true
} }
a := data.GetCompatIDs() a := data.CompatIDs()
for i := range a { for i := range a {
if strings.ToLower(a[i]) == hwidLC { if strings.ToLower(a[i]) == hwidLC {
return true return true
@@ -537,3 +556,13 @@ const (
SPDRP_MAXIMUM_PROPERTY SPDRP = 0x00000025 // Upper bound on ordinals SPDRP_MAXIMUM_PROPERTY SPDRP = 0x00000025 // Upper bound on ordinals
) )
const (
CR_SUCCESS = 0x0
CR_BUFFER_SMALL = 0x1a
)
const (
CM_GET_DEVICE_INTERFACE_LIST_PRESENT = 0 // only currently 'live' device interfaces
CM_GET_DEVICE_INTERFACE_LIST_ALL_DEVICES = 1 // all registered device interfaces, live or not
)

View File

@@ -0,0 +1,11 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package setupapi
const (
sizeofDevInfoListDetailData uint32 = 550
sizeofDrvInfoDetailData uint32 = 1570
)

View File

@@ -0,0 +1,11 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package setupapi
const (
sizeofDevInfoListDetailData uint32 = 560
sizeofDrvInfoDetailData uint32 = 1584
)

View File

@@ -38,32 +38,36 @@ func errnoErr(e syscall.Errno) error {
var ( var (
modsetupapi = windows.NewLazySystemDLL("setupapi.dll") modsetupapi = windows.NewLazySystemDLL("setupapi.dll")
modCfgMgr32 = windows.NewLazySystemDLL("CfgMgr32.dll")
procSetupDiCreateDeviceInfoListExW = modsetupapi.NewProc("SetupDiCreateDeviceInfoListExW") procSetupDiCreateDeviceInfoListExW = modsetupapi.NewProc("SetupDiCreateDeviceInfoListExW")
procSetupDiGetDeviceInfoListDetailW = modsetupapi.NewProc("SetupDiGetDeviceInfoListDetailW") procSetupDiGetDeviceInfoListDetailW = modsetupapi.NewProc("SetupDiGetDeviceInfoListDetailW")
procSetupDiCreateDeviceInfoW = modsetupapi.NewProc("SetupDiCreateDeviceInfoW") procSetupDiCreateDeviceInfoW = modsetupapi.NewProc("SetupDiCreateDeviceInfoW")
procSetupDiEnumDeviceInfo = modsetupapi.NewProc("SetupDiEnumDeviceInfo") procSetupDiEnumDeviceInfo = modsetupapi.NewProc("SetupDiEnumDeviceInfo")
procSetupDiDestroyDeviceInfoList = modsetupapi.NewProc("SetupDiDestroyDeviceInfoList") procSetupDiDestroyDeviceInfoList = modsetupapi.NewProc("SetupDiDestroyDeviceInfoList")
procSetupDiBuildDriverInfoList = modsetupapi.NewProc("SetupDiBuildDriverInfoList") procSetupDiBuildDriverInfoList = modsetupapi.NewProc("SetupDiBuildDriverInfoList")
procSetupDiCancelDriverInfoSearch = modsetupapi.NewProc("SetupDiCancelDriverInfoSearch") procSetupDiCancelDriverInfoSearch = modsetupapi.NewProc("SetupDiCancelDriverInfoSearch")
procSetupDiEnumDriverInfoW = modsetupapi.NewProc("SetupDiEnumDriverInfoW") procSetupDiEnumDriverInfoW = modsetupapi.NewProc("SetupDiEnumDriverInfoW")
procSetupDiGetSelectedDriverW = modsetupapi.NewProc("SetupDiGetSelectedDriverW") procSetupDiGetSelectedDriverW = modsetupapi.NewProc("SetupDiGetSelectedDriverW")
procSetupDiSetSelectedDriverW = modsetupapi.NewProc("SetupDiSetSelectedDriverW") procSetupDiSetSelectedDriverW = modsetupapi.NewProc("SetupDiSetSelectedDriverW")
procSetupDiGetDriverInfoDetailW = modsetupapi.NewProc("SetupDiGetDriverInfoDetailW") procSetupDiGetDriverInfoDetailW = modsetupapi.NewProc("SetupDiGetDriverInfoDetailW")
procSetupDiDestroyDriverInfoList = modsetupapi.NewProc("SetupDiDestroyDriverInfoList") procSetupDiDestroyDriverInfoList = modsetupapi.NewProc("SetupDiDestroyDriverInfoList")
procSetupDiGetClassDevsExW = modsetupapi.NewProc("SetupDiGetClassDevsExW") procSetupDiGetClassDevsExW = modsetupapi.NewProc("SetupDiGetClassDevsExW")
procSetupDiCallClassInstaller = modsetupapi.NewProc("SetupDiCallClassInstaller") procSetupDiCallClassInstaller = modsetupapi.NewProc("SetupDiCallClassInstaller")
procSetupDiOpenDevRegKey = modsetupapi.NewProc("SetupDiOpenDevRegKey") procSetupDiOpenDevRegKey = modsetupapi.NewProc("SetupDiOpenDevRegKey")
procSetupDiGetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiGetDeviceRegistryPropertyW") procSetupDiGetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiGetDeviceRegistryPropertyW")
procSetupDiSetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiSetDeviceRegistryPropertyW") procSetupDiSetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiSetDeviceRegistryPropertyW")
procSetupDiGetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiGetDeviceInstallParamsW") procSetupDiGetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiGetDeviceInstallParamsW")
procSetupDiGetClassInstallParamsW = modsetupapi.NewProc("SetupDiGetClassInstallParamsW") procSetupDiGetDeviceInstanceIdW = modsetupapi.NewProc("SetupDiGetDeviceInstanceIdW")
procSetupDiSetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiSetDeviceInstallParamsW") procSetupDiGetClassInstallParamsW = modsetupapi.NewProc("SetupDiGetClassInstallParamsW")
procSetupDiSetClassInstallParamsW = modsetupapi.NewProc("SetupDiSetClassInstallParamsW") procSetupDiSetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiSetDeviceInstallParamsW")
procSetupDiClassNameFromGuidExW = modsetupapi.NewProc("SetupDiClassNameFromGuidExW") procSetupDiSetClassInstallParamsW = modsetupapi.NewProc("SetupDiSetClassInstallParamsW")
procSetupDiClassGuidsFromNameExW = modsetupapi.NewProc("SetupDiClassGuidsFromNameExW") procSetupDiClassNameFromGuidExW = modsetupapi.NewProc("SetupDiClassNameFromGuidExW")
procSetupDiGetSelectedDevice = modsetupapi.NewProc("SetupDiGetSelectedDevice") procSetupDiClassGuidsFromNameExW = modsetupapi.NewProc("SetupDiClassGuidsFromNameExW")
procSetupDiSetSelectedDevice = modsetupapi.NewProc("SetupDiSetSelectedDevice") procSetupDiGetSelectedDevice = modsetupapi.NewProc("SetupDiGetSelectedDevice")
procSetupDiSetSelectedDevice = modsetupapi.NewProc("SetupDiSetSelectedDevice")
procCM_Get_Device_Interface_List_SizeW = modCfgMgr32.NewProc("CM_Get_Device_Interface_List_SizeW")
procCM_Get_Device_Interface_ListW = modCfgMgr32.NewProc("CM_Get_Device_Interface_ListW")
) )
func setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) { func setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) {
@@ -285,6 +289,18 @@ func setupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInf
return return
} }
func setupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, instanceId *uint16, instanceIdSize uint32, instanceIdRequiredSize *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procSetupDiGetDeviceInstanceIdW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(instanceId)), uintptr(instanceIdSize), uintptr(unsafe.Pointer(instanceIdRequiredSize)), 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) { func SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procSetupDiGetClassInstallParamsW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(classInstallParams)), uintptr(classInstallParamsSize), uintptr(unsafe.Pointer(requiredSize)), 0) r1, _, e1 := syscall.Syscall6(procSetupDiGetClassInstallParamsW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(classInstallParams)), uintptr(classInstallParamsSize), uintptr(unsafe.Pointer(requiredSize)), 0)
if r1 == 0 { if r1 == 0 {
@@ -368,3 +384,15 @@ func SetupDiSetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData
} }
return return
} }
func cm_Get_Device_Interface_List_Size(len *uint32, interfaceClass *windows.GUID, deviceID *uint16, flags uint32) (ret uint32) {
r0, _, _ := syscall.Syscall6(procCM_Get_Device_Interface_List_SizeW.Addr(), 4, uintptr(unsafe.Pointer(len)), uintptr(unsafe.Pointer(interfaceClass)), uintptr(unsafe.Pointer(deviceID)), uintptr(flags), 0, 0)
ret = uint32(r0)
return
}
func cm_Get_Device_Interface_List(interfaceClass *windows.GUID, deviceID *uint16, buffer *uint16, bufferLen uint32, flags uint32) (ret uint32) {
r0, _, _ := syscall.Syscall6(procCM_Get_Device_Interface_ListW.Addr(), 5, uintptr(unsafe.Pointer(interfaceClass)), uintptr(unsafe.Pointer(deviceID)), uintptr(unsafe.Pointer(buffer)), uintptr(bufferLen), uintptr(flags), 0)
ret = uint32(r0)
return
}

View File

@@ -14,7 +14,7 @@ import (
func TestSetupDiDestroyDeviceInfoList(t *testing.T) { func TestSetupDiDestroyDeviceInfoList(t *testing.T) {
err := SetupDiDestroyDeviceInfoList(DevInfo(windows.InvalidHandle)) err := SetupDiDestroyDeviceInfoList(DevInfo(windows.InvalidHandle))
if errWin, ok := err.(syscall.Errno); !ok || errWin != 6 /*ERROR_INVALID_HANDLE*/ { if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_INVALID_HANDLE {
t.Errorf("SetupDiDestroyDeviceInfoList(nil, ...) should fail with ERROR_INVALID_HANDLE") t.Errorf("SetupDiDestroyDeviceInfoList(nil, ...) should fail with ERROR_INVALID_HANDLE")
} }
} }

View File

@@ -8,98 +8,89 @@ package wintun
import ( import (
"errors" "errors"
"fmt" "fmt"
"golang.zx2c4.com/wireguard/tun/wintun/netshell"
"strings" "strings"
"syscall"
"time" "time"
"unsafe" "unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
"golang.zx2c4.com/wireguard/tun/wintun/guid"
"golang.zx2c4.com/wireguard/tun/wintun/netshell"
registryEx "golang.zx2c4.com/wireguard/tun/wintun/registry"
"golang.zx2c4.com/wireguard/tun/wintun/setupapi" "golang.zx2c4.com/wireguard/tun/wintun/setupapi"
) )
// // Wintun is a handle of a Wintun adapter.
// Wintun is a handle of a Wintun adapter
//
type Wintun struct { type Wintun struct {
CfgInstanceID windows.GUID cfgInstanceID windows.GUID
LUIDIndex uint32 devInstanceID string
IfType uint32 luidIndex uint32
ifType uint32
} }
var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}} var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}}
var deviceInterfaceNetGUID = windows.GUID{Data1: 0xcac88484, Data2: 0x7515, Data3: 0x4c03, Data4: [8]byte{0x82, 0xe6, 0x71, 0xa8, 0x7a, 0xba, 0xc3, 0x61}}
const hardwareID = "Wintun" const (
const enumerator = "" hardwareID = "Wintun"
const machineName = "" waitForRegistryTimeout = time.Second * 10
)
// // makeWintun creates a Wintun interface handle and populates it from the device's registry key.
// MakeWintun creates interface handle and populates it from device registry key func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) (*Wintun, error) {
//
func MakeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) (*Wintun, error) {
// Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key. // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.READ) key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE)
if err != nil { if err != nil {
return nil, errors.New("Device-specific registry key open failed: " + err.Error()) return nil, fmt.Errorf("Device-specific registry key open failed: %v", err)
} }
defer key.Close() defer key.Close()
var valueStr string
var valueType uint32
// Read the NetCfgInstanceId value. // Read the NetCfgInstanceId value.
valueStr, valueType, err = keyGetStringValueRetry(key, "NetCfgInstanceId") valueStr, err := registryEx.GetStringValue(key, "NetCfgInstanceId")
if err != nil { if err != nil {
return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error()) return nil, fmt.Errorf("RegQueryStringValue(\"NetCfgInstanceId\") failed: %v", err)
}
if valueType != registry.SZ {
return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType)
} }
// Convert to windows.GUID. // Convert to GUID.
ifid, err := guid.FromString(valueStr) ifid, err := windows.GUIDFromString(valueStr)
if err != nil { if err != nil {
return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: %q)", valueStr) return nil, fmt.Errorf("NetCfgInstanceId registry value is not a GUID (expected: \"{...}\", provided: %q)", valueStr)
} }
// Read the NetLuidIndex value. // Read the NetLuidIndex value.
luidIdx, valueType, err := key.GetIntegerValue("NetLuidIndex") luidIdx, _, err := key.GetIntegerValue("NetLuidIndex")
if err != nil { if err != nil {
return nil, errors.New("RegQueryValue(\"NetLuidIndex\") failed: " + err.Error()) return nil, fmt.Errorf("RegQueryValue(\"NetLuidIndex\") failed: %v", err)
} }
// Read the NetLuidIndex value. // Read the NetLuidIndex value.
ifType, valueType, err := key.GetIntegerValue("*IfType") ifType, _, err := key.GetIntegerValue("*IfType")
if err != nil { if err != nil {
return nil, errors.New("RegQueryValue(\"*IfType\") failed: " + err.Error()) return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err)
}
instanceID, err := deviceInfoSet.DeviceInstanceID(deviceInfoData)
if err != nil {
return nil, fmt.Errorf("DeviceInstanceID failed: %v", err)
} }
return &Wintun{ return &Wintun{
CfgInstanceID: *ifid, cfgInstanceID: ifid,
LUIDIndex: uint32(luidIdx), devInstanceID: instanceID,
IfType: uint32(ifType), luidIndex: uint32(luidIdx),
ifType: uint32(ifType),
}, nil }, nil
} }
// // GetInterface finds a Wintun interface by its name. This function returns
// GetInterface finds interface by name. // the interface if found, or windows.ERROR_OBJECT_NOT_FOUND otherwise. If
// // the interface is found but not a Wintun-class, this function returns
// hwndParent is a handle to the top-level window to use for any user // windows.ERROR_ALREADY_EXISTS.
// interface that is related to non-device-specific actions (such as a select- func GetInterface(ifname string) (*Wintun, error) {
// device dialog box that uses the global class driver list). This handle is
// optional and can be 0. If a specific top-level window is not required, set
// hwndParent to 0.
//
// Function returns interface if found, or nil otherwise. If the interface is
// found but not Wintun-class, the function returns interface and an error.
//
func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
// Create a list of network devices. // Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, enumerator, hwndParent, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), machineName) devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil { if err != nil {
return nil, errors.New(fmt.Sprintf("SetupDiGetClassDevsEx(%v) failed: ", guid.ToString(&deviceClassNetGUID)) + err.Error()) return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err)
} }
defer devInfoList.Close() defer devInfoList.Close()
@@ -111,152 +102,141 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
ifname = strings.ToLower(ifname) ifname = strings.ToLower(ifname)
for index := 0; ; index++ { for index := 0; ; index++ {
// Get the device from the list. Should anything be wrong with this device, continue with next.
deviceData, err := devInfoList.EnumDeviceInfo(index) deviceData, err := devInfoList.EnumDeviceInfo(index)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if err == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
} }
// Get interface ID. wintun, err := makeWintun(devInfoList, deviceData)
wintun, err := MakeWintun(devInfoList, deviceData)
if err != nil { if err != nil {
continue continue
} }
// Get interface name. // TODO: is there a better way than comparing ifnames?
ifname2, err := wintun.GetInterfaceName() ifname2, err := wintun.InterfaceName()
if err != nil { if err != nil {
continue continue
} }
if ifname == strings.ToLower(ifname2) { if ifname == strings.ToLower(ifname2) {
// Interface name found. Check its driver. err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
const driverType = setupapi.SPDIT_COMPATDRIVER
err = devInfoList.BuildDriverInfoList(deviceData, driverType)
if err != nil { if err != nil {
return nil, errors.New("SetupDiBuildDriverInfoList failed: " + err.Error()) return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
} }
defer devInfoList.DestroyDriverInfoList(deviceData, driverType) defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
for index := 0; ; index++ { for index := 0; ; index++ {
// Get a driver from the list. driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, index)
driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, index)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if err == windows.ERROR_NO_MORE_ITEMS {
break break
} }
// Something is wrong with this driver. Skip it.
continue continue
} }
// Get driver info details. // Get driver info details.
driverDetailData, err := devInfoList.GetDriverInfoDetail(deviceData, driverData) driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
if err != nil { if err != nil {
// Something is wrong with this driver. Skip it.
continue continue
} }
if driverDetailData.IsCompatible(hardwareID) { if driverDetailData.IsCompatible(hardwareID) {
// Matching hardware ID found.
return wintun, nil return wintun, nil
} }
} }
// This interface is not using Wintun driver. // This interface is not using Wintun driver.
return nil, errors.New("Foreign network interface with the same name exists") return nil, windows.ERROR_ALREADY_EXISTS
} }
} }
return nil, nil return nil, windows.ERROR_OBJECT_NOT_FOUND
} }
// CreateInterface creates a Wintun interface. description is a string that
// supplies the text description of the device. The description is optional
// and can be "". requestedGUID is the GUID of the created network interface,
// which then influences NLA generation deterministically. If it is set to nil,
// the GUID is chosen by the system at random, and hence a new NLA entry is
// created for each new interface. It is called "requested" GUID because the
// API it uses is completely undocumented, and so there could be minor
// interesting complications with its usage. This function returns the network
// interface ID and a flag if reboot is required.
// //
// CreateInterface creates a TUN interface. func CreateInterface(description string, requestedGUID *windows.GUID) (wintun *Wintun, rebootRequired bool, err error) {
//
// description is a string that supplies the text description of the device.
// description is optional and can be "".
//
// hwndParent is a handle to the top-level window to use for any user
// interface that is related to non-device-specific actions (such as a select-
// device dialog box that uses the global class driver list). This handle is
// optional and can be 0. If a specific top-level window is not required, set
// hwndParent to 0.
//
// Function returns the network interface ID and a flag if reboot is required.
//
func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, error) {
// Create an empty device info set for network adapter device class. // Create an empty device info set for network adapter device class.
devInfoList, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, hwndParent, machineName) devInfoList, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
if err != nil { if err != nil {
return nil, false, errors.New(fmt.Sprintf("SetupDiCreateDeviceInfoListEx(%v) failed: ", guid.ToString(&deviceClassNetGUID)) + err.Error()) err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err)
return
} }
defer devInfoList.Close()
// Get the device class name from GUID. // Get the device class name from GUID.
className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, machineName) className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "")
if err != nil { if err != nil {
return nil, false, errors.New(fmt.Sprintf("SetupDiClassNameFromGuidEx(%v) failed: ", guid.ToString(&deviceClassNetGUID)) + err.Error()) err = fmt.Errorf("SetupDiClassNameFromGuidEx(%v) failed: %v", deviceClassNetGUID, err)
return
} }
// Create a new device info element and add it to the device info set. // Create a new device info element and add it to the device info set.
deviceData, err := devInfoList.CreateDeviceInfo(className, &deviceClassNetGUID, description, hwndParent, setupapi.DICD_GENERATE_ID) deviceData, err := devInfoList.CreateDeviceInfo(className, &deviceClassNetGUID, description, 0, setupapi.DICD_GENERATE_ID)
if err != nil { if err != nil {
return nil, false, errors.New("SetupDiCreateDeviceInfo failed: " + err.Error()) err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err)
return
}
err = setQuietInstall(devInfoList, deviceData)
if err != nil {
err = fmt.Errorf("Setting quiet installation failed: %v", err)
return
} }
// Set a device information element as the selected member of a device information set. // Set a device information element as the selected member of a device information set.
err = devInfoList.SetSelectedDevice(deviceData) err = devInfoList.SetSelectedDevice(deviceData)
if err != nil { if err != nil {
return nil, false, errors.New("SetupDiSetSelectedDevice failed: " + err.Error()) err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err)
return
} }
// Set Plug&Play device hardware ID property. // Set Plug&Play device hardware ID property.
hwid, err := syscall.UTF16FromString(hardwareID) err = devInfoList.SetDeviceRegistryPropertyString(deviceData, setupapi.SPDRP_HARDWAREID, hardwareID)
if err != nil { if err != nil {
return nil, false, err // syscall.UTF16FromString(hardwareID) should never fail: hardwareID is const string without NUL chars. err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err)
} return
err = devInfoList.SetDeviceRegistryProperty(deviceData, setupapi.SPDRP_HARDWAREID, setupapi.UTF16ToBuf(append(hwid, 0)))
if err != nil {
return nil, false, errors.New("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: " + err.Error())
} }
// Search for the driver. err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms
const driverType = setupapi.SPDIT_CLASSDRIVER
err = devInfoList.BuildDriverInfoList(deviceData, driverType)
if err != nil { if err != nil {
return nil, false, errors.New("SetupDiBuildDriverInfoList failed: " + err.Error()) err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
return
} }
defer devInfoList.DestroyDriverInfoList(deviceData, driverType) defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
driverDate := windows.Filetime{} driverDate := windows.Filetime{}
driverVersion := uint64(0) driverVersion := uint64(0)
for index := 0; ; index++ { for index := 0; ; index++ { // TODO: This loop takes ~600ms
// Get a driver from the list. driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, index)
driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, index)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if err == windows.ERROR_NO_MORE_ITEMS {
break break
} }
// Something is wrong with this driver. Skip it.
continue continue
} }
// Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match. // Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match.
if driverData.IsNewer(driverDate, driverVersion) { if driverData.IsNewer(driverDate, driverVersion) {
// Get driver info details. driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
driverDetailData, err := devInfoList.GetDriverInfoDetail(deviceData, driverData)
if err != nil { if err != nil {
// Something is wrong with this driver. Skip it.
continue continue
} }
if driverDetailData.IsCompatible(hardwareID) { if driverDetailData.IsCompatible(hardwareID) {
// Matching hardware ID found. Select the driver.
err := devInfoList.SetSelectedDriver(deviceData, driverData) err := devInfoList.SetSelectedDriver(deviceData, driverData)
if err != nil { if err != nil {
// Something is wrong with this driver. Skip it.
continue continue
} }
@@ -267,244 +247,415 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
} }
if driverVersion == 0 { if driverVersion == 0 {
return nil, false, fmt.Errorf("No driver for device %q installed", hardwareID) err = fmt.Errorf("No driver for device %q installed", hardwareID)
return
} }
// Call appropriate class installer. // Call appropriate class installer.
err = devInfoList.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, deviceData) err = devInfoList.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, deviceData)
if err != nil { if err != nil {
return nil, false, errors.New("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: " + err.Error()) err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err)
return
} }
// Register device co-installers if any. (Ignore errors) // Register device co-installers if any. (Ignore errors)
devInfoList.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, deviceData) devInfoList.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, deviceData)
// Install interfaces if any. (Ignore errors) var key registry.Key
devInfoList.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, deviceData) const pollTimeout = time.Millisecond * 50
for i := 0; i < int(waitForRegistryTimeout/pollTimeout); i++ {
var wintun *Wintun if i != 0 {
var rebootRequired bool time.Sleep(pollTimeout)
// Install the device.
err = devInfoList.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, deviceData)
if err != nil {
err = errors.New("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: " + err.Error())
}
if err == nil {
// Check if a system reboot is required. (Ignore errors)
if ret, _ := checkReboot(devInfoList, deviceData); ret {
rebootRequired = true
} }
key, err = devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY)
// Get network interface. DIF_INSTALLDEVICE returns almost immediately, while the device if err == nil {
// installation continues in the background. It might take a while, before all registry
// keys and values are populated.
for numAttempts := 0; numAttempts < 30; numAttempts++ {
wintun, err = MakeWintun(devInfoList, deviceData)
if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND {
// Wait and retry. TODO: Wait for a cancellable event instead.
err = errors.New("Time-out waiting for adapter to get ready")
time.Sleep(time.Second / 4)
continue
}
}
break break
} }
} }
if err == nil {
return wintun, rebootRequired, nil
}
// The interface failed to install, or the interface ID was unobtainable. Clean-up.
removeDeviceParams := setupapi.RemoveDeviceParams{
ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
}
// Set class installer parameters for DIF_REMOVE.
if devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil {
// Call appropriate class installer.
if devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) == nil {
// Check if a system reboot is required. (Ignore errors)
if ret, _ := checkReboot(devInfoList, deviceData); ret {
rebootRequired = true
}
}
}
return nil, rebootRequired, err
}
//
// DeleteInterface deletes a TUN interface.
//
// hwndParent is a handle to the top-level window to use for any user
// interface that is related to non-device-specific actions (such as a select-
// device dialog box that uses the global class driver list). This handle is
// optional and can be 0. If a specific top-level window is not required, set
// hwndParent to 0.
//
// Function returns true if the interface was found and deleted and a flag if
// reboot is required.
//
func (wintun *Wintun) DeleteInterface(hwndParent uintptr) (bool, bool, error) {
// Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, enumerator, hwndParent, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), machineName)
if err != nil { if err != nil {
return false, false, errors.New(fmt.Sprintf("SetupDiGetClassDevsEx(%v) failed: ", guid.ToString(&deviceClassNetGUID)) + err.Error()) err = fmt.Errorf("SetupDiOpenDevRegKey failed: %v", err)
return
} }
defer devInfoList.Close() defer key.Close()
if requestedGUID != nil {
// Iterate. err = key.SetStringValue("NetSetupAnticipatedInstanceId", requestedGUID.String())
for index := 0; ; index++ {
// Get the device from the list. Should anything be wrong with this device, continue with next.
deviceData, err := devInfoList.EnumDeviceInfo(index)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { err = fmt.Errorf("SetStringValue(NetSetupAnticipatedInstanceId) failed: %v", err)
break return
}
continue
} }
}
// Get interface ID. // Install interfaces if any. (Ignore errors)
wintun2, err := MakeWintun(devInfoList, deviceData) devInfoList.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, deviceData)
defer func() {
if err != nil { if err != nil {
continue // The interface failed to install, or the interface ID was unobtainable. Clean-up.
}
if wintun.CfgInstanceID == wintun2.CfgInstanceID {
// Remove the device.
removeDeviceParams := setupapi.RemoveDeviceParams{ removeDeviceParams := setupapi.RemoveDeviceParams{
ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE), ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
Scope: setupapi.DI_REMOVEDEVICE_GLOBAL, Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
} }
// Set class installer parameters for DIF_REMOVE. // Set class installer parameters for DIF_REMOVE.
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) if devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil {
if err != nil { // Call appropriate class installer.
return false, false, errors.New("SetupDiSetClassInstallParams failed: " + err.Error()) if devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) == nil {
rebootRequired = rebootRequired || checkReboot(devInfoList, deviceData)
}
} }
// Call appropriate class installer. wintun = nil
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData)
if err != nil {
return false, false, errors.New("SetupDiCallClassInstaller failed: " + err.Error())
}
// Check if a system reboot is required. (Ignore errors)
if ret, _ := checkReboot(devInfoList, deviceData); ret {
return true, true, nil
}
return true, false, nil
} }
}()
// Install the device.
err = devInfoList.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, deviceData)
if err != nil {
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err)
return
}
rebootRequired = checkReboot(devInfoList, deviceData)
// DIF_INSTALLDEVICE returns almost immediately, while the device installation
// continues in the background. It might take a while, before all registry
// keys and values are populated.
_, err = registryEx.GetStringValueWait(key, "NetCfgInstanceId", waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("GetStringValueWait(NetCfgInstanceId) failed: %v", err)
return
}
_, err = registryEx.GetIntegerValueWait(key, "NetLuidIndex", waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("GetIntegerValueWait(NetLuidIndex) failed: %v", err)
return
}
_, err = registryEx.GetIntegerValueWait(key, "*IfType", waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("GetIntegerValueWait(*IfType) failed: %v", err)
return
} }
return false, false, nil // Get network interface.
wintun, err = makeWintun(devInfoList, deviceData)
if err != nil {
err = fmt.Errorf("makeWintun failed: %v", err)
return
}
// Wait for network registry key to emerge and populate.
key, err = registryEx.OpenKeyWait(
registry.LOCAL_MACHINE,
wintun.netRegKeyName(),
registry.QUERY_VALUE|registry.NOTIFY,
waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("makeWintun failed: %v", err)
return
}
defer key.Close()
_, err = registryEx.GetStringValueWait(key, "Name", waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("GetStringValueWait(Name) failed: %v", err)
return
}
_, err = registryEx.GetStringValueWait(key, "PnPInstanceId", waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("GetStringValueWait(PnPInstanceId) failed: %v", err)
return
}
// Wait for TCP/IP adapter registry key to emerge and populate.
key, err = registryEx.OpenKeyWait(
registry.LOCAL_MACHINE,
wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE|registry.NOTIFY,
waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", wintun.tcpipAdapterRegKeyName(), err)
return
}
defer key.Close()
_, err = registryEx.GetStringValueWait(key, "IpConfig", waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("GetStringValueWait(IpConfig) failed: %v", err)
return
}
tcpipInterfaceRegKeyName, err := wintun.tcpipInterfaceRegKeyName()
if err != nil {
err = fmt.Errorf("tcpipInterfaceRegKeyName failed: %v", err)
return
}
// Wait for TCP/IP interface registry key to emerge.
key, err = registryEx.OpenKeyWait(
registry.LOCAL_MACHINE,
tcpipInterfaceRegKeyName, registry.QUERY_VALUE,
waitForRegistryTimeout)
if err != nil {
err = fmt.Errorf("OpenKeyWait(HKLM\\%s) failed: %v", tcpipInterfaceRegKeyName, err)
return
}
key.Close()
//
// All the registry keys and values we're relying on are present now.
//
// Disable dead gateway detection on our interface.
key, err = registry.OpenKey(registry.LOCAL_MACHINE, tcpipInterfaceRegKeyName, registry.SET_VALUE)
if err != nil {
err = fmt.Errorf("Error opening interface-specific TCP/IP network registry key: %v", err)
return
}
key.SetDWordValue("EnableDeadGWDetect", 0)
key.Close()
return
} }
// // DeleteInterface deletes a Wintun interface. This function succeeds
// FlushInterface removes all properties from the interface and gives it only a very // if the interface was not found. It returns a bool indicating whether
// vanilla IPv4 and IPv6 configuration with no addresses of any sort assigned. // a reboot is required.
// func (wintun *Wintun) DeleteInterface() (rebootRequired bool, err error) {
func (wintun *Wintun) FlushInterface() error { devInfoList, deviceData, err := wintun.deviceData()
//TODO: implement me! if err == windows.ERROR_OBJECT_NOT_FOUND {
return nil return false, nil
} }
//
// checkReboot checks device install parameters if a system reboot is required.
//
func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) (bool, error) {
devInstallParams, err := deviceInfoSet.GetDeviceInstallParams(deviceInfoData)
if err != nil { if err != nil {
return false, err return false, err
} }
defer devInfoList.Close()
if (devInstallParams.Flags & (setupapi.DI_NEEDREBOOT | setupapi.DI_NEEDRESTART)) != 0 { // Remove the device.
return true, nil removeDeviceParams := setupapi.RemoveDeviceParams{
ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
} }
return false, nil // Set class installer parameters for DIF_REMOVE.
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
if err != nil {
return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err)
}
// Call appropriate class installer.
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData)
if err != nil {
return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err)
}
return checkReboot(devInfoList, deviceData), nil
} }
// // DeleteAllInterfaces deletes all Wintun interfaces, and returns which
// GetInterfaceName returns network interface name. // ones it deleted, whether a reboot is required after, and which errors
// // occurred during the process.
func (wintun *Wintun) GetInterfaceName() (string, error) { func DeleteAllInterfaces() (deviceInstancesDeleted []uint32, rebootRequired bool, errors []error) {
key, err := registryOpenKeyRetry(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE) devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil { if err != nil {
return "", errors.New("Network-specific registry key open failed: " + err.Error()) return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())}
}
defer devInfoList.Close()
for i := 0; ; i++ {
deviceData, err := devInfoList.EnumDeviceInfo(i)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
}
continue
}
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
if err != nil {
continue
}
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
isWintun := false
for j := 0; ; j++ {
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, j)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
}
continue
}
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
if err != nil {
continue
}
if driverDetailData.IsCompatible(hardwareID) {
isWintun = true
break
}
}
if !isWintun {
continue
}
err = setQuietInstall(devInfoList, deviceData)
if err != nil {
errors = append(errors, err)
continue
}
inst := deviceData.DevInst
removeDeviceParams := setupapi.RemoveDeviceParams{
ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
}
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
if err != nil {
errors = append(errors, err)
continue
}
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData)
if err != nil {
errors = append(errors, err)
continue
}
rebootRequired = rebootRequired || checkReboot(devInfoList, deviceData)
deviceInstancesDeleted = append(deviceInstancesDeleted, inst)
}
return
}
// checkReboot checks device install parameters if a system reboot is required.
func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) bool {
devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData)
if err != nil {
return false
}
return (devInstallParams.Flags & (setupapi.DI_NEEDREBOOT | setupapi.DI_NEEDRESTART)) != 0
}
// setQuietInstall sets device install parameters for a quiet installation
func setQuietInstall(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) error {
devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData)
if err != nil {
return err
}
devInstallParams.Flags |= setupapi.DI_QUIETINSTALL
return deviceInfoSet.SetDeviceInstallParams(deviceInfoData, devInstallParams)
}
// InterfaceName returns the name of the Wintun interface.
func (wintun *Wintun) InterfaceName() (string, error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.netRegKeyName(), registry.QUERY_VALUE)
if err != nil {
return "", fmt.Errorf("Network-specific registry key open failed: %v", err)
} }
defer key.Close() defer key.Close()
// Get the interface name. // Get the interface name.
return getRegStringValue(key, "Name") return registryEx.GetStringValue(key, "Name")
} }
// // SetInterfaceName sets name of the Wintun interface.
// SetInterfaceName sets network interface name.
//
func (wintun *Wintun) SetInterfaceName(ifname string) error { func (wintun *Wintun) SetInterfaceName(ifname string) error {
// We open the registry key before calling HrRename, because the registry open will wait until the key exists.
key, err := registryOpenKeyRetry(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE)
if err != nil {
return errors.New("Network-specific registry key open failed: " + err.Error())
}
defer key.Close()
// We have to tell the various runtime COM services about the new name too. We ignore the // We have to tell the various runtime COM services about the new name too. We ignore the
// error because netshell isn't available on servercore. // error because netshell isn't available on servercore.
// TODO: netsh.exe falls back to NciSetConnection in this case. If somebody complains, maybe // TODO: netsh.exe falls back to NciSetConnection in this case. If somebody complains, maybe
// we should do the same. // we should do the same.
_ = netshell.HrRenameConnection(&wintun.CfgInstanceID, windows.StringToUTF16Ptr(ifname)) netshell.HrRenameConnection(&wintun.cfgInstanceID, windows.StringToUTF16Ptr(ifname))
// Set the interface name. The above line should have done this too, but in case it failed, we force it. // Set the interface name. The above line should have done this too, but in case it failed, we force it.
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.netRegKeyName(), registry.SET_VALUE)
if err != nil {
return fmt.Errorf("Network-specific registry key open failed: %v", err)
}
defer key.Close()
return key.SetStringValue("Name", ifname) return key.SetStringValue("Name", ifname)
} }
// // netRegKeyName returns the interface-specific network registry key name.
// GetNetRegKeyName returns interface-specific network registry key name. func (wintun *Wintun) netRegKeyName() string {
// return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Control\\Network\\%v\\%v\\Connection", deviceClassNetGUID, wintun.cfgInstanceID)
func (wintun *Wintun) GetNetRegKeyName() string {
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Control\\Network\\%v\\%v\\Connection", guid.ToString(&deviceClassNetGUID), guid.ToString(&wintun.CfgInstanceID))
} }
// // tcpipAdapterRegKeyName returns the adapter-specific TCP/IP network registry key name.
// getRegStringValue function reads a string value from registry. func (wintun *Wintun) tcpipAdapterRegKeyName() string {
// return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", wintun.cfgInstanceID)
// If the value type is REG_EXPAND_SZ the environment variables are expanded. }
// Should expanding fail, original string value and nil error are returned.
// // tcpipInterfaceRegKeyName returns the interface-specific TCP/IP network registry key name.
func getRegStringValue(key registry.Key, name string) (string, error) { func (wintun *Wintun) tcpipInterfaceRegKeyName() (path string, err error) {
// Read string value. key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE)
value, valueType, err := keyGetStringValueRetry(key, name)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %v", err)
} }
paths, _, err := key.GetStringsValue("IpConfig")
if valueType != registry.EXPAND_SZ { key.Close()
// Value does not require expansion.
return value, nil
}
valueExp, err := registry.ExpandString(value)
if err != nil { if err != nil {
// Expanding failed: return original sting value. return "", fmt.Errorf("Error reading IpConfig registry key: %v", err)
return value, nil }
if len(paths) == 0 {
return "", errors.New("No TCP/IP interfaces found on adapter")
}
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil
}
// deviceData returns TUN device info list handle and interface device info
// data. The device info list handle must be closed after use. In case the
// device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned.
func (wintun *Wintun) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData, error) {
// Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil {
return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())
} }
// Return expanded value. for index := 0; ; index++ {
return valueExp, nil deviceData, err := devInfoList.EnumDeviceInfo(index)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
}
continue
}
// Get interface ID.
// TODO: Store some ID in the Wintun object such that this call isn't required.
wintun2, err := makeWintun(devInfoList, deviceData)
if err != nil {
continue
}
if wintun.cfgInstanceID == wintun2.cfgInstanceID {
err = setQuietInstall(devInfoList, deviceData)
if err != nil {
devInfoList.Close()
return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err)
}
return devInfoList, deviceData, nil
}
}
devInfoList.Close()
return 0, nil, windows.ERROR_OBJECT_NOT_FOUND
} }
// // AdapterHandle returns a handle to the adapter device object.
// DataFileName returns Wintun device data pipe name. func (wintun *Wintun) AdapterHandle() (windows.Handle, error) {
// interfaces, err := setupapi.CM_Get_Device_Interface_List(wintun.devInstanceID, &deviceInterfaceNetGUID, setupapi.CM_GET_DEVICE_INTERFACE_LIST_PRESENT)
func (wintun *Wintun) DataFileName() string { if err != nil {
return fmt.Sprintf("\\\\.\\Global\\WINTUN%d", wintun.LUIDIndex) return windows.InvalidHandle, err
}
handle, err := windows.CreateFile(windows.StringToUTF16Ptr(interfaces[0]), windows.GENERIC_READ|windows.GENERIC_WRITE, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE|windows.FILE_SHARE_DELETE, nil, windows.OPEN_EXISTING, 0, 0)
if err != nil {
return windows.InvalidHandle, fmt.Errorf("Open NDIS device failed: %v", err)
}
return handle, nil
}
// GUID returns the GUID of the interface.
func (wintun *Wintun) GUID() windows.GUID {
return wintun.cfgInstanceID
}
// LUID returns the LUID of the interface.
func (wintun *Wintun) LUID() uint64 {
return ((uint64(wintun.luidIndex) & ((1 << 24) - 1)) << 24) | ((uint64(wintun.ifType) & ((1 << 16) - 1)) << 48)
} }