Compare commits
130 Commits
0.0.201905
...
0.0.202001
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05b03c6750 | ||
|
|
caebdfe9d0 | ||
|
|
4fa2ea6a2d | ||
|
|
89dd065e53 | ||
|
|
ddfad453cf | ||
|
|
2b242f9393 | ||
|
|
4cdf805b29 | ||
|
|
f7d0edd2ec | ||
|
|
ffffbbcc8a | ||
|
|
47b02c618b | ||
|
|
fd23c66fcd | ||
|
|
ae492d1b35 | ||
|
|
95fbfccf60 | ||
|
|
c85e4a410f | ||
|
|
1b6c8ddbe8 | ||
|
|
0abb6b668c | ||
|
|
540d01e54a | ||
|
|
f2ea85e9f9 | ||
|
|
222f0f8000 | ||
|
|
1f146a5e7a | ||
|
|
f2501aa6c8 | ||
|
|
cb8d01f58a | ||
|
|
01f8ef4e84 | ||
|
|
70f6c42556 | ||
|
|
bb0b2514c0 | ||
|
|
7c97fdb1e3 | ||
|
|
84b5a4d83d | ||
|
|
4cd06c0925 | ||
|
|
d12eb91f9a | ||
|
|
73d3bd9cd5 | ||
|
|
f3dba4c194 | ||
|
|
7937840f96 | ||
|
|
e4b957183c | ||
|
|
950ca2ba8c | ||
|
|
df2bf34373 | ||
|
|
a12b765784 | ||
|
|
14df9c3e75 | ||
|
|
353f0956bc | ||
|
|
fa7763c268 | ||
|
|
d94bae8348 | ||
|
|
7689d09336 | ||
|
|
69c26dc258 | ||
|
|
e862131d3c | ||
|
|
da28a3e9f3 | ||
|
|
3bf3322b2c | ||
|
|
7305b4ce93 | ||
|
|
26fb615b11 | ||
|
|
7fbb24afaa | ||
|
|
d9008ac35c | ||
|
|
f8198c0428 | ||
|
|
0c540ad60e | ||
|
|
3cedc22d7b | ||
|
|
68fea631d8 | ||
|
|
ef23100a4f | ||
|
|
eb786cd7c1 | ||
|
|
333de75370 | ||
|
|
d20459dc69 | ||
|
|
01786286c1 | ||
|
|
b16dba47a7 | ||
|
|
4be9630ddc | ||
|
|
4e3018a967 | ||
|
|
b4010123f7 | ||
|
|
1ff37e2b07 | ||
|
|
f5e54932e6 | ||
|
|
73698066d1 | ||
|
|
05ece4d167 | ||
|
|
6d78f89557 | ||
|
|
a2249449d6 | ||
|
|
eeeac287ef | ||
|
|
b5a7cbf069 | ||
|
|
50cd522cb0 | ||
|
|
5ba866a5c8 | ||
|
|
2f101fedec | ||
|
|
3341e2d444 | ||
|
|
1b550f6583 | ||
|
|
7bc0e11831 | ||
|
|
31ff9c02fe | ||
|
|
1e39c33ab1 | ||
|
|
6c50fedd8e | ||
|
|
298d759f3e | ||
|
|
4d5819183e | ||
|
|
9ea9a92117 | ||
|
|
2e24e7dcae | ||
|
|
a961aacc9f | ||
|
|
b0cf53b078 | ||
|
|
5c3d333f10 | ||
|
|
d8448f8a02 | ||
|
|
13abbdf14b | ||
|
|
f361e59001 | ||
|
|
b844f1b3cc | ||
|
|
dd8817f50e | ||
|
|
5e6eff81b6 | ||
|
|
c69d026649 | ||
|
|
1f48971a80 | ||
|
|
3371f8dac6 | ||
|
|
41fdbf0971 | ||
|
|
03eee4a778 | ||
|
|
700860f8e6 | ||
|
|
a304f69e0d | ||
|
|
baafe92888 | ||
|
|
a1a97d1e41 | ||
|
|
e924280baa | ||
|
|
bb3f1932fa | ||
|
|
eaf17becfa | ||
|
|
6d8b68c8f3 | ||
|
|
c2ed133df8 | ||
|
|
108c37a056 | ||
|
|
e4b0ef29a1 | ||
|
|
625e445b22 | ||
|
|
85b85e62e5 | ||
|
|
014f736480 | ||
|
|
43a4589043 | ||
|
|
8d76ac8cc4 | ||
|
|
18b6627f33 | ||
|
|
80ef2a42e6 | ||
|
|
da61947ec3 | ||
|
|
d9f995209c | ||
|
|
d0ab883ada | ||
|
|
32912dc778 | ||
|
|
d4034e5f8a | ||
|
|
fbcd995ec1 | ||
|
|
e7e286ba6c | ||
|
|
f70546bc2e | ||
|
|
6a0a3a5406 | ||
|
|
8fdcf5ee30 | ||
|
|
a74a29bc93 | ||
|
|
dc9bbec9db | ||
|
|
a6dbe4f475 | ||
|
|
c718f3940d | ||
|
|
95c70b8032 |
16
Makefile
16
Makefile
@@ -1,30 +1,16 @@
|
|||||||
PREFIX ?= /usr
|
PREFIX ?= /usr
|
||||||
DESTDIR ?=
|
DESTDIR ?=
|
||||||
BINDIR ?= $(PREFIX)/bin
|
BINDIR ?= $(PREFIX)/bin
|
||||||
export GOPATH ?= $(CURDIR)/.gopath
|
|
||||||
export GO111MODULE := on
|
export GO111MODULE := on
|
||||||
|
|
||||||
all: generate-version-and-build
|
all: generate-version-and-build
|
||||||
|
|
||||||
ifeq ($(shell go env GOOS)|$(wildcard .git),linux|)
|
|
||||||
$(error Do not build this for Linux. Instead use the Linux kernel module. See wireguard.com/install/ for more info.)
|
|
||||||
else ifeq ($(shell go env GOOS),linux)
|
|
||||||
ireallywantobuildon_linux.go:
|
|
||||||
@printf "WARNING: This software is meant for use on non-Linux\nsystems. For Linux, please use the kernel module\ninstead. See wireguard.com/install/ for more info.\n\n" >&2
|
|
||||||
@printf 'package main\nconst UseTheKernelModuleInstead = 0xdeadbabe\n' > "$@"
|
|
||||||
clean-ireallywantobuildon_linux.go:
|
|
||||||
@rm -f ireallywantobuildon_linux.go
|
|
||||||
.PHONY: clean-ireallywantobuildon_linux.go
|
|
||||||
clean: clean-ireallywantobuildon_linux.go
|
|
||||||
wireguard-go: ireallywantobuildon_linux.go
|
|
||||||
endif
|
|
||||||
|
|
||||||
MAKEFLAGS += --no-print-directory
|
MAKEFLAGS += --no-print-directory
|
||||||
|
|
||||||
generate-version-and-build:
|
generate-version-and-build:
|
||||||
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
||||||
tag="$$(git describe --dirty 2>/dev/null)" && \
|
tag="$$(git describe --dirty 2>/dev/null)" && \
|
||||||
ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$$tag")" && \
|
ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$${tag#v}")" && \
|
||||||
[ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \
|
[ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \
|
||||||
echo "$$ver" > device/version.go && \
|
echo "$$ver" > device/version.go && \
|
||||||
git update-index --assume-unchanged device/version.go || true
|
git update-index --assume-unchanged device/version.go || true
|
||||||
|
|||||||
@@ -2,8 +2,6 @@
|
|||||||
|
|
||||||
This is an implementation of WireGuard in Go.
|
This is an implementation of WireGuard in Go.
|
||||||
|
|
||||||
***WARNING:*** This is a work in progress and not ready for prime time, with no official "releases" yet. It is extremely rough around the edges and leaves much to be desired. There are bugs and we are not yet in a position to make claims about its security. Beware.
|
|
||||||
|
|
||||||
## Usage
|
## Usage
|
||||||
|
|
||||||
Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run:
|
Most Linux kernel WireGuard users are used to adding an interface with `ip link add wg0 type wireguard`. With wireguard-go, instead simply run:
|
||||||
@@ -20,7 +18,7 @@ To run wireguard-go without forking to the background, pass `-f` or `--foregroun
|
|||||||
$ wireguard-go -f wg0
|
$ wireguard-go -f wg0
|
||||||
```
|
```
|
||||||
|
|
||||||
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/WireGuard/about/src/tools/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
||||||
|
|
||||||
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
||||||
|
|
||||||
@@ -36,7 +34,7 @@ This runs on macOS using the utun driver. It does not yet support sticky sockets
|
|||||||
|
|
||||||
### Windows
|
### Windows
|
||||||
|
|
||||||
It is currently a work in progress to strip out the beginnings of an experiment done with the OpenVPN tuntap driver and instead port to the new UWP APIs for tunnels. In other words, this does not *yet* work on Windows.
|
This runs on Windows, but you should instead use it from the more [fully featured Windows app](https://git.zx2c4.com/wireguard-windows/about/), which uses this as a module.
|
||||||
|
|
||||||
### FreeBSD
|
### FreeBSD
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,14 @@
|
|||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
|
import "errors"
|
||||||
|
|
||||||
func (device *Device) PeekLookAtSocketFd4() (fd int, err error) {
|
func (device *Device) PeekLookAtSocketFd4() (fd int, err error) {
|
||||||
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
|
nb, ok := device.net.bind.(*nativeBind)
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("no socket exists")
|
||||||
|
}
|
||||||
|
sysconn, err := nb.ipv4.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -20,7 +26,11 @@ func (device *Device) PeekLookAtSocketFd4() (fd int, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) PeekLookAtSocketFd6() (fd int, err error) {
|
func (device *Device) PeekLookAtSocketFd6() (fd int, err error) {
|
||||||
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
|
nb, ok := device.net.bind.(*nativeBind)
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("no socket exists")
|
||||||
|
}
|
||||||
|
sysconn, err := nb.ipv6.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,44 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package device
|
|
||||||
|
|
||||||
import (
|
|
||||||
"golang.org/x/sys/unix"
|
|
||||||
)
|
|
||||||
|
|
||||||
func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
|
|
||||||
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
err2 := sysconn.Control(func(fd uintptr) {
|
|
||||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, int(interfaceIndex))
|
|
||||||
})
|
|
||||||
if err2 != nil {
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
|
|
||||||
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
err2 := sysconn.Control(func(fd uintptr) {
|
|
||||||
err = unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, int(interfaceIndex))
|
|
||||||
})
|
|
||||||
if err2 != nil {
|
|
||||||
return err2
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -7,6 +7,7 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
@@ -17,12 +18,16 @@ const (
|
|||||||
sockoptIPV6_UNICAST_IF = 31
|
sockoptIPV6_UNICAST_IF = 31
|
||||||
)
|
)
|
||||||
|
|
||||||
func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
|
func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||||
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
|
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
|
||||||
bytes := make([]byte, 4)
|
bytes := make([]byte, 4)
|
||||||
binary.BigEndian.PutUint32(bytes, interfaceIndex)
|
binary.BigEndian.PutUint32(bytes, interfaceIndex)
|
||||||
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
|
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
|
||||||
|
|
||||||
|
if device.net.bind == nil {
|
||||||
|
return errors.New("Bind is not yet initialized")
|
||||||
|
}
|
||||||
|
|
||||||
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
|
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -36,10 +41,11 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
device.net.bind.(*nativeBind).blackhole4 = blackhole
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
|
func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||||
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
|
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -53,5 +59,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
device.net.bind.(*nativeBind).blackhole6 = blackhole
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,8 +21,10 @@ import (
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
type nativeBind struct {
|
type nativeBind struct {
|
||||||
ipv4 *net.UDPConn
|
ipv4 *net.UDPConn
|
||||||
ipv6 *net.UDPConn
|
ipv6 *net.UDPConn
|
||||||
|
blackhole4 bool
|
||||||
|
blackhole6 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type NativeEndpoint net.UDPAddr
|
type NativeEndpoint net.UDPAddr
|
||||||
@@ -159,11 +161,17 @@ func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
|
|||||||
if bind.ipv4 == nil {
|
if bind.ipv4 == nil {
|
||||||
return syscall.EAFNOSUPPORT
|
return syscall.EAFNOSUPPORT
|
||||||
}
|
}
|
||||||
|
if bind.blackhole4 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
||||||
} else {
|
} else {
|
||||||
if bind.ipv6 == nil {
|
if bind.ipv6 == nil {
|
||||||
return syscall.EAFNOSUPPORT
|
return syscall.EAFNOSUPPORT
|
||||||
}
|
}
|
||||||
|
if bind.blackhole6 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
* This implements userspace semantics of "sticky sockets", modeled after
|
* This implements userspace semantics of "sticky sockets", modeled after
|
||||||
* WireGuard's kernelspace implementation. This is more or less a straight port
|
* WireGuard's kernelspace implementation. This is more or less a straight port
|
||||||
* of the sticky-sockets.c example code:
|
* of the sticky-sockets.c example code:
|
||||||
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
|
* https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c
|
||||||
*
|
*
|
||||||
* Currently there is no way to achieve this within the net package:
|
* Currently there is no way to achieve this within the net package:
|
||||||
* See e.g. https://github.com/golang/go/issues/17930
|
* See e.g. https://github.com/golang/go/issues/17930
|
||||||
@@ -43,6 +43,7 @@ type IPv6Source struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type NativeEndpoint struct {
|
type NativeEndpoint struct {
|
||||||
|
sync.Mutex
|
||||||
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
|
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
|
||||||
src [unsafe.Sizeof(IPv6Source{})]byte
|
src [unsafe.Sizeof(IPv6Source{})]byte
|
||||||
isV6 bool
|
isV6 bool
|
||||||
@@ -145,7 +146,7 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
|
|||||||
|
|
||||||
go bind.routineRouteListener(device)
|
go bind.routineRouteListener(device)
|
||||||
|
|
||||||
// attempt ipv6 bind, update port if succesful
|
// attempt ipv6 bind, update port if successful
|
||||||
|
|
||||||
bind.sock6, newPort, err = create6(port)
|
bind.sock6, newPort, err = create6(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -157,7 +158,7 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
|
|||||||
port = newPort
|
port = newPort
|
||||||
}
|
}
|
||||||
|
|
||||||
// attempt ipv4 bind, update port if succesful
|
// attempt ipv4 bind, update port if successful
|
||||||
|
|
||||||
bind.sock4, newPort, err = create4(port)
|
bind.sock4, newPort, err = create4(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -391,6 +392,11 @@ func create4(port uint16) (int, uint16, error) {
|
|||||||
return FD_ERR, 0, err
|
return FD_ERR, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sa, err := unix.Getsockname(fd)
|
||||||
|
if err == nil {
|
||||||
|
addr.Port = sa.(*unix.SockaddrInet4).Port
|
||||||
|
}
|
||||||
|
|
||||||
return fd, uint16(addr.Port), err
|
return fd, uint16(addr.Port), err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -450,6 +456,11 @@ func create6(port uint16) (int, uint16, error) {
|
|||||||
return FD_ERR, 0, err
|
return FD_ERR, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sa, err := unix.Getsockname(fd)
|
||||||
|
if err == nil {
|
||||||
|
addr.Port = sa.(*unix.SockaddrInet6).Port
|
||||||
|
}
|
||||||
|
|
||||||
return fd, uint16(addr.Port), err
|
return fd, uint16(addr.Port), err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -472,7 +483,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
end.Lock()
|
||||||
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
||||||
|
end.Unlock()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -483,7 +496,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
if err == unix.EINVAL {
|
if err == unix.EINVAL {
|
||||||
end.ClearSrc()
|
end.ClearSrc()
|
||||||
cmsg.pktinfo = unix.Inet4Pktinfo{}
|
cmsg.pktinfo = unix.Inet4Pktinfo{}
|
||||||
|
end.Lock()
|
||||||
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
||||||
|
end.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@@ -512,7 +527,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
cmsg.pktinfo.Ifindex = 0
|
cmsg.pktinfo.Ifindex = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
end.Lock()
|
||||||
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
||||||
|
end.Unlock()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -523,7 +540,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
if err == unix.EINVAL {
|
if err == unix.EINVAL {
|
||||||
end.ClearSrc()
|
end.ClearSrc()
|
||||||
cmsg.pktinfo = unix.Inet6Pktinfo{}
|
cmsg.pktinfo = unix.Inet6Pktinfo{}
|
||||||
|
end.Lock()
|
||||||
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
||||||
|
end.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@@ -531,7 +550,7 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
|
|
||||||
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||||
|
|
||||||
// contruct message header
|
// construct message header
|
||||||
|
|
||||||
var cmsg struct {
|
var cmsg struct {
|
||||||
cmsghdr unix.Cmsghdr
|
cmsghdr unix.Cmsghdr
|
||||||
@@ -563,7 +582,7 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
|||||||
|
|
||||||
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||||
|
|
||||||
// contruct message header
|
// construct message header
|
||||||
|
|
||||||
var cmsg struct {
|
var cmsg struct {
|
||||||
cmsghdr unix.Cmsghdr
|
cmsghdr unix.Cmsghdr
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
/* Specification constants */
|
/* Specification constants */
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
|
RekeyAfterMessages = (1 << 60)
|
||||||
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
|
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
|
||||||
RekeyAfterTime = time.Second * 120
|
RekeyAfterTime = time.Second * 120
|
||||||
RekeyAttemptTime = time.Second * 90
|
RekeyAttemptTime = time.Second * 90
|
||||||
@@ -22,7 +22,7 @@ const (
|
|||||||
RejectAfterTime = time.Second * 180
|
RejectAfterTime = time.Second * 180
|
||||||
KeepaliveTimeout = time.Second * 10
|
KeepaliveTimeout = time.Second * 10
|
||||||
CookieRefreshTime = time.Second * 120
|
CookieRefreshTime = time.Second * 120
|
||||||
HandshakeInitationRate = time.Second / 20
|
HandshakeInitationRate = time.Second / 50
|
||||||
PaddingMultiple = 16
|
PaddingMultiple = 16
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -87,7 +87,7 @@ func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
|
|||||||
st.RLock()
|
st.RLock()
|
||||||
defer st.RUnlock()
|
defer st.RUnlock()
|
||||||
|
|
||||||
if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime {
|
if time.Since(st.mac2.secretSet) > CookieRefreshTime {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,7 +124,7 @@ func (st *CookieChecker) CreateReply(
|
|||||||
|
|
||||||
// refresh cookie secret
|
// refresh cookie secret
|
||||||
|
|
||||||
if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime {
|
if time.Since(st.mac2.secretSet) > CookieRefreshTime {
|
||||||
st.RUnlock()
|
st.RUnlock()
|
||||||
st.Lock()
|
st.Lock()
|
||||||
_, err := rand.Read(st.mac2.secret[:])
|
_, err := rand.Read(st.mac2.secret[:])
|
||||||
@@ -239,7 +239,7 @@ func (st *CookieGenerator) AddMacs(msg []byte) {
|
|||||||
|
|
||||||
// set mac2
|
// set mac2
|
||||||
|
|
||||||
if time.Now().Sub(st.mac2.cookieSet) > CookieRefreshTime {
|
if time.Since(st.mac2.cookieSet) > CookieRefreshTime {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -86,7 +86,7 @@ type Device struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
tun struct {
|
tun struct {
|
||||||
device tun.TUNDevice
|
device tun.Device
|
||||||
mtu int32
|
mtu int32
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -133,6 +133,7 @@ func deviceUpdateState(device *Device) {
|
|||||||
switch newIsUp {
|
switch newIsUp {
|
||||||
case true:
|
case true:
|
||||||
if err := device.BindUpdate(); err != nil {
|
if err := device.BindUpdate(); err != nil {
|
||||||
|
device.log.Error.Printf("Unable to update bind: %v\n", err)
|
||||||
device.isUp.Set(false)
|
device.isUp.Set(false)
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -200,18 +201,22 @@ func (device *Device) IsUnderLoad() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
||||||
|
|
||||||
// lock required resources
|
// lock required resources
|
||||||
|
|
||||||
device.staticIdentity.Lock()
|
device.staticIdentity.Lock()
|
||||||
defer device.staticIdentity.Unlock()
|
defer device.staticIdentity.Unlock()
|
||||||
|
|
||||||
|
if sk.Equals(device.staticIdentity.privateKey) {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
device.peers.Lock()
|
device.peers.Lock()
|
||||||
defer device.peers.Unlock()
|
defer device.peers.Unlock()
|
||||||
|
|
||||||
|
lockedPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
||||||
for _, peer := range device.peers.keyMap {
|
for _, peer := range device.peers.keyMap {
|
||||||
peer.handshake.mutex.RLock()
|
peer.handshake.mutex.RLock()
|
||||||
defer peer.handshake.mutex.RUnlock()
|
lockedPeers = append(lockedPeers, peer)
|
||||||
}
|
}
|
||||||
|
|
||||||
// remove peers with matching public keys
|
// remove peers with matching public keys
|
||||||
@@ -233,8 +238,8 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|||||||
|
|
||||||
rmKey := device.staticIdentity.privateKey.IsZero()
|
rmKey := device.staticIdentity.privateKey.IsZero()
|
||||||
|
|
||||||
|
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
||||||
for key, peer := range device.peers.keyMap {
|
for key, peer := range device.peers.keyMap {
|
||||||
|
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
|
|
||||||
if rmKey {
|
if rmKey {
|
||||||
@@ -245,13 +250,22 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
|||||||
|
|
||||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||||
unsafeRemovePeer(device, peer, key)
|
unsafeRemovePeer(device, peer, key)
|
||||||
|
} else {
|
||||||
|
expiredPeers = append(expiredPeers, peer)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for _, peer := range lockedPeers {
|
||||||
|
peer.handshake.mutex.RUnlock()
|
||||||
|
}
|
||||||
|
for _, peer := range expiredPeers {
|
||||||
|
peer.ExpireCurrentKeypairs()
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device {
|
func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
||||||
device := new(Device)
|
device := new(Device)
|
||||||
|
|
||||||
device.isUp.Set(false)
|
device.isUp.Set(false)
|
||||||
@@ -323,7 +337,6 @@ func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
|
|||||||
func (device *Device) RemovePeer(key NoisePublicKey) {
|
func (device *Device) RemovePeer(key NoisePublicKey) {
|
||||||
device.peers.Lock()
|
device.peers.Lock()
|
||||||
defer device.peers.Unlock()
|
defer device.peers.Unlock()
|
||||||
|
|
||||||
// stop peer and remove from routing
|
// stop peer and remove from routing
|
||||||
|
|
||||||
peer, ok := device.peers.keyMap[key]
|
peer, ok := device.peers.keyMap[key]
|
||||||
@@ -395,3 +408,20 @@ func (device *Device) Close() {
|
|||||||
func (device *Device) Wait() chan struct{} {
|
func (device *Device) Wait() chan struct{} {
|
||||||
return device.signals.stop
|
return device.signals.stop
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
||||||
|
if device.isClosed.Get() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
device.peers.RLock()
|
||||||
|
for _, peer := range device.peers.keyMap {
|
||||||
|
peer.keypairs.RLock()
|
||||||
|
sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now())
|
||||||
|
peer.keypairs.RUnlock()
|
||||||
|
if sendKeepalive {
|
||||||
|
peer.SendKeepalive()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
device.peers.RUnlock()
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,44 +5,234 @@
|
|||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
/* Create two device instances and simulate full WireGuard interaction
|
import (
|
||||||
* without network dependencies
|
"bufio"
|
||||||
*/
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
import "testing"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
func TestDevice(t *testing.T) {
|
func TestTwoDevicePing(t *testing.T) {
|
||||||
|
// TODO(crawshaw): pick unused ports on localhost
|
||||||
// prepare tun devices for generating traffic
|
cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58
|
||||||
|
listen_port=53511
|
||||||
tun1, err := CreateDummyTUN("tun1")
|
replace_peers=true
|
||||||
if err != nil {
|
public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725
|
||||||
t.Error("failed to create tun:", err.Error())
|
protocol_version=1
|
||||||
|
replace_allowed_ips=true
|
||||||
|
allowed_ip=1.0.0.2/32
|
||||||
|
endpoint=127.0.0.1:53512`
|
||||||
|
tun1 := NewChannelTUN()
|
||||||
|
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
|
||||||
|
dev1.Up()
|
||||||
|
defer dev1.Close()
|
||||||
|
if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tun2, err := CreateDummyTUN("tun2")
|
cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768
|
||||||
if err != nil {
|
listen_port=53512
|
||||||
t.Error("failed to create tun:", err.Error())
|
replace_peers=true
|
||||||
|
public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427
|
||||||
|
protocol_version=1
|
||||||
|
replace_allowed_ips=true
|
||||||
|
allowed_ip=1.0.0.1/32
|
||||||
|
endpoint=127.0.0.1:53511`
|
||||||
|
tun2 := NewChannelTUN()
|
||||||
|
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
|
||||||
|
dev2.Up()
|
||||||
|
defer dev2.Close()
|
||||||
|
if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = tun1
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
_ = tun2
|
msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
|
||||||
|
tun2.Outbound <- msg2to1
|
||||||
// prepare endpoints
|
select {
|
||||||
|
case msgRecv := <-tun1.Inbound:
|
||||||
end1, err := CreateDummyEndpoint()
|
if !bytes.Equal(msg2to1, msgRecv) {
|
||||||
if err != nil {
|
t.Error("ping did not transit correctly")
|
||||||
t.Error("failed to create endpoint:", err.Error())
|
}
|
||||||
}
|
case <-time.After(300 * time.Millisecond):
|
||||||
|
t.Error("ping did not transit")
|
||||||
end2, err := CreateDummyEndpoint()
|
}
|
||||||
if err != nil {
|
})
|
||||||
t.Error("failed to create endpoint:", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = end1
|
|
||||||
_ = end2
|
|
||||||
|
|
||||||
// create binds
|
|
||||||
|
|
||||||
|
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||||
|
msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
|
||||||
|
tun1.Outbound <- msg1to2
|
||||||
|
select {
|
||||||
|
case msgRecv := <-tun2.Inbound:
|
||||||
|
if !bytes.Equal(msg1to2, msgRecv) {
|
||||||
|
t.Error("return ping did not transit correctly")
|
||||||
|
}
|
||||||
|
case <-time.After(300 * time.Millisecond):
|
||||||
|
t.Error("return ping did not transit")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ping(dst, src net.IP) []byte {
|
||||||
|
localPort := uint16(1337)
|
||||||
|
seq := uint16(0)
|
||||||
|
|
||||||
|
payload := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint16(payload[0:], localPort)
|
||||||
|
binary.BigEndian.PutUint16(payload[2:], seq)
|
||||||
|
|
||||||
|
return genICMPv4(payload, dst, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
|
||||||
|
func checksum(buf []byte, initial uint16) uint16 {
|
||||||
|
v := uint32(initial)
|
||||||
|
for i := 0; i < len(buf)-1; i += 2 {
|
||||||
|
v += uint32(binary.BigEndian.Uint16(buf[i:]))
|
||||||
|
}
|
||||||
|
if len(buf)%2 == 1 {
|
||||||
|
v += uint32(buf[len(buf)-1]) << 8
|
||||||
|
}
|
||||||
|
for v > 0xffff {
|
||||||
|
v = (v >> 16) + (v & 0xffff)
|
||||||
|
}
|
||||||
|
return ^uint16(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func genICMPv4(payload []byte, dst, src net.IP) []byte {
|
||||||
|
const (
|
||||||
|
icmpv4ProtocolNumber = 1
|
||||||
|
icmpv4Echo = 8
|
||||||
|
icmpv4ChecksumOffset = 2
|
||||||
|
icmpv4Size = 8
|
||||||
|
ipv4Size = 20
|
||||||
|
ipv4TotalLenOffset = 2
|
||||||
|
ipv4ChecksumOffset = 10
|
||||||
|
ttl = 65
|
||||||
|
)
|
||||||
|
|
||||||
|
hdr := make([]byte, ipv4Size+icmpv4Size)
|
||||||
|
|
||||||
|
ip := hdr[0:ipv4Size]
|
||||||
|
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
|
||||||
|
|
||||||
|
// https://tools.ietf.org/html/rfc792
|
||||||
|
icmpv4[0] = icmpv4Echo // type
|
||||||
|
icmpv4[1] = 0 // code
|
||||||
|
chksum := ^checksum(icmpv4, checksum(payload, 0))
|
||||||
|
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
|
||||||
|
|
||||||
|
// https://tools.ietf.org/html/rfc760 section 3.1
|
||||||
|
length := uint16(len(hdr) + len(payload))
|
||||||
|
ip[0] = (4 << 4) | (ipv4Size / 4)
|
||||||
|
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
|
||||||
|
ip[8] = ttl
|
||||||
|
ip[9] = icmpv4ProtocolNumber
|
||||||
|
copy(ip[12:], src.To4())
|
||||||
|
copy(ip[16:], dst.To4())
|
||||||
|
chksum = ^checksum(ip[:], 0)
|
||||||
|
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
|
||||||
|
|
||||||
|
var v []byte
|
||||||
|
v = append(v, hdr...)
|
||||||
|
v = append(v, payload...)
|
||||||
|
return []byte(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(crawshaw): find a reusable home for this. package devicetest?
|
||||||
|
type ChannelTUN struct {
|
||||||
|
Inbound chan []byte // incoming packets, closed on TUN close
|
||||||
|
Outbound chan []byte // outbound packets, blocks forever on TUN close
|
||||||
|
|
||||||
|
closed chan struct{}
|
||||||
|
events chan tun.Event
|
||||||
|
tun chTun
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChannelTUN() *ChannelTUN {
|
||||||
|
c := &ChannelTUN{
|
||||||
|
Inbound: make(chan []byte),
|
||||||
|
Outbound: make(chan []byte),
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
events: make(chan tun.Event, 1),
|
||||||
|
}
|
||||||
|
c.tun.c = c
|
||||||
|
c.events <- tun.EventUp
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChannelTUN) TUN() tun.Device {
|
||||||
|
return &c.tun
|
||||||
|
}
|
||||||
|
|
||||||
|
type chTun struct {
|
||||||
|
c *ChannelTUN
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *chTun) File() *os.File { return nil }
|
||||||
|
|
||||||
|
func (t *chTun) Read(data []byte, offset int) (int, error) {
|
||||||
|
select {
|
||||||
|
case <-t.c.closed:
|
||||||
|
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
|
||||||
|
case msg := <-t.c.Outbound:
|
||||||
|
return copy(data[offset:], msg), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write is called by the wireguard device to deliver a packet for routing.
|
||||||
|
func (t *chTun) Write(data []byte, offset int) (int, error) {
|
||||||
|
if offset == -1 {
|
||||||
|
close(t.c.closed)
|
||||||
|
close(t.c.events)
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
msg := make([]byte, len(data)-offset)
|
||||||
|
copy(msg, data[offset:])
|
||||||
|
select {
|
||||||
|
case <-t.c.closed:
|
||||||
|
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
|
||||||
|
case t.c.Inbound <- msg:
|
||||||
|
return len(data) - offset, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *chTun) Flush() error { return nil }
|
||||||
|
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
|
||||||
|
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
|
||||||
|
func (t *chTun) Events() chan tun.Event { return t.c.events }
|
||||||
|
func (t *chTun) Close() error {
|
||||||
|
t.Write(nil, -1)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertNil(t *testing.T, err error) {
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func assertEqual(t *testing.T, a, b []byte) {
|
||||||
|
if !bytes.Equal(a, b) {
|
||||||
|
t.Fatal(a, "!=", b)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func randDevice(t *testing.T) *Device {
|
||||||
|
sk, err := newPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tun := newDummyTUN("dummy")
|
||||||
|
logger := NewLogger(LogLevelError, "")
|
||||||
|
device := NewDevice(tun, logger)
|
||||||
|
device.SetPrivateKey(sk)
|
||||||
|
return device
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,6 @@ func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) {
|
|||||||
func KDF1(t0 *[blake2s.Size]byte, key, input []byte) {
|
func KDF1(t0 *[blake2s.Size]byte, key, input []byte) {
|
||||||
HMAC1(t0, key, input)
|
HMAC1(t0, key, input)
|
||||||
HMAC1(t0, t0[:], []byte{0x1})
|
HMAC1(t0, t0[:], []byte{0x1})
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
|
func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
|
||||||
@@ -51,7 +50,6 @@ func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
|
|||||||
HMAC1(t0, prk[:], []byte{0x1})
|
HMAC1(t0, prk[:], []byte{0x1})
|
||||||
HMAC2(t1, prk[:], t0[:], []byte{0x2})
|
HMAC2(t1, prk[:], t0[:], []byte{0x2})
|
||||||
setZero(prk[:])
|
setZero(prk[:])
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
|
func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
|
||||||
@@ -61,7 +59,6 @@ func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
|
|||||||
HMAC2(t1, prk[:], t0[:], []byte{0x2})
|
HMAC2(t1, prk[:], t0[:], []byte{0x2})
|
||||||
HMAC2(t2, prk[:], t1[:], []byte{0x3})
|
HMAC2(t2, prk[:], t1[:], []byte{0x3})
|
||||||
setZero(prk[:])
|
setZero(prk[:])
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func isZero(val []byte) bool {
|
func isZero(val []byte) bool {
|
||||||
|
|||||||
@@ -39,13 +39,13 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MessageInitiationSize = 148 // size of handshake initation message
|
MessageInitiationSize = 148 // size of handshake initiation message
|
||||||
MessageResponseSize = 92 // size of response message
|
MessageResponseSize = 92 // size of response message
|
||||||
MessageCookieReplySize = 64 // size of cookie reply message
|
MessageCookieReplySize = 64 // size of cookie reply message
|
||||||
MessageTransportHeaderSize = 16 // size of data preceeding content in transport message
|
MessageTransportHeaderSize = 16 // size of data preceding content in transport message
|
||||||
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
|
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
|
||||||
MessageKeepaliveSize = MessageTransportSize // size of keepalive
|
MessageKeepaliveSize = MessageTransportSize // size of keepalive
|
||||||
MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message
|
MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -301,7 +301,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||||||
|
|
||||||
var ok bool
|
var ok bool
|
||||||
ok = timestamp.After(handshake.lastTimestamp)
|
ok = timestamp.After(handshake.lastTimestamp)
|
||||||
ok = ok && time.Now().Sub(handshake.lastInitiationConsumption) > HandshakeInitationRate
|
ok = ok && time.Since(handshake.lastInitiationConsumption) > HandshakeInitationRate
|
||||||
handshake.mutex.RUnlock()
|
handshake.mutex.RUnlock()
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
@@ -315,8 +315,13 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||||||
handshake.chainKey = chainKey
|
handshake.chainKey = chainKey
|
||||||
handshake.remoteIndex = msg.Sender
|
handshake.remoteIndex = msg.Sender
|
||||||
handshake.remoteEphemeral = msg.Ephemeral
|
handshake.remoteEphemeral = msg.Ephemeral
|
||||||
handshake.lastTimestamp = timestamp
|
if timestamp.After(handshake.lastTimestamp) {
|
||||||
handshake.lastInitiationConsumption = time.Now()
|
handshake.lastTimestamp = timestamp
|
||||||
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if now.After(handshake.lastInitiationConsumption) {
|
||||||
|
handshake.lastInitiationConsumption = now
|
||||||
|
}
|
||||||
handshake.state = HandshakeInitiationConsumed
|
handshake.state = HandshakeInitiationConsumed
|
||||||
|
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -67,7 +68,6 @@ type Peer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||||
|
|
||||||
if device.isClosed.Get() {
|
if device.isClosed.Get() {
|
||||||
return nil, errors.New("device closed")
|
return nil, errors.New("device closed")
|
||||||
}
|
}
|
||||||
@@ -102,20 +102,28 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
|||||||
if ok {
|
if ok {
|
||||||
return nil, errors.New("adding existing peer")
|
return nil, errors.New("adding existing peer")
|
||||||
}
|
}
|
||||||
device.peers.keyMap[pk] = peer
|
|
||||||
|
|
||||||
// pre-compute DH
|
// pre-compute DH
|
||||||
|
|
||||||
handshake := &peer.handshake
|
handshake := &peer.handshake
|
||||||
handshake.mutex.Lock()
|
handshake.mutex.Lock()
|
||||||
handshake.remoteStatic = pk
|
|
||||||
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
|
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
|
||||||
|
ssIsZero := isZero(handshake.precomputedStaticStatic[:])
|
||||||
|
handshake.remoteStatic = pk
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|
||||||
// reset endpoint
|
// reset endpoint
|
||||||
|
|
||||||
peer.endpoint = nil
|
peer.endpoint = nil
|
||||||
|
|
||||||
|
// conditionally add
|
||||||
|
|
||||||
|
if !ssIsZero {
|
||||||
|
device.peers.keyMap[pk] = peer
|
||||||
|
} else {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
// start peer
|
// start peer
|
||||||
|
|
||||||
if peer.device.isUp.Get() {
|
if peer.device.isUp.Get() {
|
||||||
@@ -140,7 +148,11 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
|
|||||||
return errors.New("no known endpoint for peer")
|
return errors.New("no known endpoint for peer")
|
||||||
}
|
}
|
||||||
|
|
||||||
return peer.device.net.bind.Send(buffer, peer.endpoint)
|
err := peer.device.net.bind.Send(buffer, peer.endpoint)
|
||||||
|
if err == nil {
|
||||||
|
atomic.AddUint64(&peer.stats.txBytes, uint64(len(buffer)))
|
||||||
|
}
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) String() string {
|
func (peer *Peer) String() string {
|
||||||
@@ -227,6 +239,25 @@ func (peer *Peer) ZeroAndFlushAll() {
|
|||||||
peer.FlushNonceQueue()
|
peer.FlushNonceQueue()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (peer *Peer) ExpireCurrentKeypairs() {
|
||||||
|
handshake := &peer.handshake
|
||||||
|
handshake.mutex.Lock()
|
||||||
|
peer.device.indexTable.Delete(handshake.localIndex)
|
||||||
|
handshake.Clear()
|
||||||
|
handshake.mutex.Unlock()
|
||||||
|
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
|
||||||
|
|
||||||
|
keypairs := &peer.keypairs
|
||||||
|
keypairs.Lock()
|
||||||
|
if keypairs.current != nil {
|
||||||
|
keypairs.current.sendNonce = RejectAfterMessages
|
||||||
|
}
|
||||||
|
if keypairs.next != nil {
|
||||||
|
keypairs.next.sendNonce = RejectAfterMessages
|
||||||
|
}
|
||||||
|
keypairs.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
func (peer *Peer) Stop() {
|
func (peer *Peer) Stop() {
|
||||||
|
|
||||||
// prevent simultaneous start/stop operations
|
// prevent simultaneous start/stop operations
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
keypair := peer.keypairs.Current()
|
keypair := peer.keypairs.Current()
|
||||||
if keypair != nil && keypair.isInitiator && time.Now().Sub(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
|
if keypair != nil && keypair.isInitiator && time.Since(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
|
||||||
peer.timers.sentLastMinuteHandshake.Set(true)
|
peer.timers.sentLastMinuteHandshake.Set(true)
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
}
|
}
|
||||||
@@ -427,6 +427,7 @@ func (device *Device) RoutineHandshake() {
|
|||||||
peer.SetEndpointFromPacket(elem.endpoint)
|
peer.SetEndpointFromPacket(elem.endpoint)
|
||||||
|
|
||||||
logDebug.Println(peer, "- Received handshake initiation")
|
logDebug.Println(peer, "- Received handshake initiation")
|
||||||
|
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
||||||
|
|
||||||
peer.SendHandshakeResponse()
|
peer.SendHandshakeResponse()
|
||||||
|
|
||||||
@@ -457,6 +458,7 @@ func (device *Device) RoutineHandshake() {
|
|||||||
peer.SetEndpointFromPacket(elem.endpoint)
|
peer.SetEndpointFromPacket(elem.endpoint)
|
||||||
|
|
||||||
logDebug.Println(peer, "- Received handshake response")
|
logDebug.Println(peer, "- Received handshake response")
|
||||||
|
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
||||||
|
|
||||||
// update timers
|
// update timers
|
||||||
|
|
||||||
@@ -483,33 +485,6 @@ func (device *Device) RoutineHandshake() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) elementStopOrFlush(shouldFlush *bool) (stop bool, elemOk bool, elem *QueueInboundElement) {
|
|
||||||
if !*shouldFlush {
|
|
||||||
select {
|
|
||||||
case <-peer.routines.stop:
|
|
||||||
stop = true
|
|
||||||
return
|
|
||||||
case elem, elemOk = <-peer.queue.inbound:
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
select {
|
|
||||||
case <-peer.routines.stop:
|
|
||||||
stop = true
|
|
||||||
return
|
|
||||||
case elem, elemOk = <-peer.queue.inbound:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
*shouldFlush = false
|
|
||||||
err := peer.device.tun.device.Flush()
|
|
||||||
if err != nil {
|
|
||||||
peer.device.log.Error.Printf("Unable to flush packets: %v", err)
|
|
||||||
}
|
|
||||||
return peer.elementStopOrFlush(shouldFlush)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (peer *Peer) RoutineSequentialReceiver() {
|
func (peer *Peer) RoutineSequentialReceiver() {
|
||||||
|
|
||||||
device := peer.device
|
device := peer.device
|
||||||
@@ -518,10 +493,6 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
|||||||
logDebug := device.log.Debug
|
logDebug := device.log.Debug
|
||||||
|
|
||||||
var elem *QueueInboundElement
|
var elem *QueueInboundElement
|
||||||
var ok bool
|
|
||||||
var stop bool
|
|
||||||
|
|
||||||
shouldFlush := false
|
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
logDebug.Println(peer, "- Routine: sequential receiver - stopped")
|
logDebug.Println(peer, "- Routine: sequential receiver - stopped")
|
||||||
@@ -547,9 +518,14 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
|||||||
elem = nil
|
elem = nil
|
||||||
}
|
}
|
||||||
|
|
||||||
stop, ok, elem = peer.elementStopOrFlush(&shouldFlush)
|
var elemOk bool
|
||||||
if stop || !ok {
|
select {
|
||||||
|
case <-peer.routines.stop:
|
||||||
return
|
return
|
||||||
|
case elem, elemOk = <-peer.queue.inbound:
|
||||||
|
if !elemOk {
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// wait for decryption
|
// wait for decryption
|
||||||
@@ -581,6 +557,7 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
|||||||
peer.keepKeyFreshReceiving()
|
peer.keepKeyFreshReceiving()
|
||||||
peer.timersAnyAuthenticatedPacketTraversal()
|
peer.timersAnyAuthenticatedPacketTraversal()
|
||||||
peer.timersAnyAuthenticatedPacketReceived()
|
peer.timersAnyAuthenticatedPacketReceived()
|
||||||
|
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
|
||||||
|
|
||||||
// check for keepalive
|
// check for keepalive
|
||||||
|
|
||||||
@@ -642,8 +619,8 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
|||||||
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
|
||||||
if device.allowedips.LookupIPv6(src) != peer {
|
if device.allowedips.LookupIPv6(src) != peer {
|
||||||
logInfo.Println(
|
logInfo.Println(
|
||||||
|
"IPv6 packet with disallowed source address from",
|
||||||
peer,
|
peer,
|
||||||
"sent packet with disallowed IPv6 source",
|
|
||||||
)
|
)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -656,10 +633,12 @@ func (peer *Peer) RoutineSequentialReceiver() {
|
|||||||
// write to tun device
|
// write to tun device
|
||||||
|
|
||||||
offset := MessageTransportOffsetContent
|
offset := MessageTransportOffsetContent
|
||||||
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
|
|
||||||
_, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
|
_, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
|
||||||
if err == nil {
|
if len(peer.queue.inbound) == 0 {
|
||||||
shouldFlush = true
|
err = device.tun.device.Flush()
|
||||||
|
if err != nil {
|
||||||
|
peer.device.log.Error.Printf("Unable to flush packets: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if err != nil && !device.isClosed.Get() {
|
if err != nil && !device.isClosed.Get() {
|
||||||
logError.Println("Failed to write packet to TUN device:", err)
|
logError.Println("Failed to write packet to TUN device:", err)
|
||||||
|
|||||||
@@ -129,14 +129,14 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
peer.handshake.mutex.RLock()
|
peer.handshake.mutex.RLock()
|
||||||
if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout {
|
if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
|
||||||
peer.handshake.mutex.RUnlock()
|
peer.handshake.mutex.RUnlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
peer.handshake.mutex.RUnlock()
|
peer.handshake.mutex.RUnlock()
|
||||||
|
|
||||||
peer.handshake.mutex.Lock()
|
peer.handshake.mutex.Lock()
|
||||||
if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout {
|
if time.Since(peer.handshake.lastSentHandshake) < RekeyTimeout {
|
||||||
peer.handshake.mutex.Unlock()
|
peer.handshake.mutex.Unlock()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -220,10 +220,7 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
|
|||||||
writer := bytes.NewBuffer(buff[:0])
|
writer := bytes.NewBuffer(buff[:0])
|
||||||
binary.Write(writer, binary.LittleEndian, reply)
|
binary.Write(writer, binary.LittleEndian, reply)
|
||||||
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
|
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
|
||||||
if err != nil {
|
return nil
|
||||||
device.log.Error.Println("Failed to send cookie reply:", err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) keepKeyFreshSending() {
|
func (peer *Peer) keepKeyFreshSending() {
|
||||||
@@ -232,7 +229,7 @@ func (peer *Peer) keepKeyFreshSending() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
nonce := atomic.LoadUint64(&keypair.sendNonce)
|
nonce := atomic.LoadUint64(&keypair.sendNonce)
|
||||||
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Now().Sub(keypair.created) > RekeyAfterTime) {
|
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Since(keypair.created) > RekeyAfterTime) {
|
||||||
peer.SendHandshakeInitiation(false)
|
peer.SendHandshakeInitiation(false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -390,7 +387,7 @@ func (peer *Peer) RoutineNonce() {
|
|||||||
|
|
||||||
keypair = peer.keypairs.Current()
|
keypair = peer.keypairs.Current()
|
||||||
if keypair != nil && keypair.sendNonce < RejectAfterMessages {
|
if keypair != nil && keypair.sendNonce < RejectAfterMessages {
|
||||||
if time.Now().Sub(keypair.created) < RejectAfterTime {
|
if time.Since(keypair.created) < RejectAfterTime {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -600,7 +597,6 @@ func (peer *Peer) RoutineSequentialSender() {
|
|||||||
|
|
||||||
// send message and return buffer to pool
|
// send message and return buffer to pool
|
||||||
|
|
||||||
length := uint64(len(elem.packet))
|
|
||||||
err := peer.SendBuffer(elem.packet)
|
err := peer.SendBuffer(elem.packet)
|
||||||
if len(elem.packet) != MessageKeepaliveSize {
|
if len(elem.packet) != MessageKeepaliveSize {
|
||||||
peer.timersDataSent()
|
peer.timersDataSent()
|
||||||
@@ -611,7 +607,6 @@ func (peer *Peer) RoutineSequentialSender() {
|
|||||||
logError.Println(peer, "- Failed to send data packet", err)
|
logError.Println(peer, "- Failed to send data packet", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
atomic.AddUint64(&peer.stats.txBytes, length)
|
|
||||||
|
|
||||||
peer.keepKeyFreshSending()
|
peer.keepKeyFreshSending()
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -147,7 +147,7 @@ func expiredPersistentKeepalive(peer *Peer) {
|
|||||||
/* Should be called after an authenticated data packet is sent. */
|
/* Should be called after an authenticated data packet is sent. */
|
||||||
func (peer *Peer) timersDataSent() {
|
func (peer *Peer) timersDataSent() {
|
||||||
if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
|
if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
|
||||||
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout)
|
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -23,7 +23,7 @@ func (device *Device) RoutineTUNEventReader() {
|
|||||||
device.state.starting.Done()
|
device.state.starting.Done()
|
||||||
|
|
||||||
for event := range device.tun.device.Events() {
|
for event := range device.tun.device.Events() {
|
||||||
if event&tun.TUNEventMTUUpdate != 0 {
|
if event&tun.EventMTUUpdate != 0 {
|
||||||
mtu, err := device.tun.device.MTU()
|
mtu, err := device.tun.device.MTU()
|
||||||
old := atomic.LoadInt32(&device.tun.mtu)
|
old := atomic.LoadInt32(&device.tun.mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -38,13 +38,13 @@ func (device *Device) RoutineTUNEventReader() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if event&tun.TUNEventUp != 0 && !setUp {
|
if event&tun.EventUp != 0 && !setUp {
|
||||||
logInfo.Println("Interface set up")
|
logInfo.Println("Interface set up")
|
||||||
setUp = true
|
setUp = true
|
||||||
device.Up()
|
device.Up()
|
||||||
}
|
}
|
||||||
|
|
||||||
if event&tun.TUNEventDown != 0 && setUp {
|
if event&tun.EventDown != 0 && setUp {
|
||||||
logInfo.Println("Interface set down")
|
logInfo.Println("Interface set down")
|
||||||
setUp = false
|
setUp = false
|
||||||
device.Down()
|
device.Down()
|
||||||
|
|||||||
56
device/tun_test.go
Normal file
56
device/tun_test.go
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package device
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
|
)
|
||||||
|
|
||||||
|
// newDummyTUN creates a dummy TUN device with the specified name.
|
||||||
|
func newDummyTUN(name string) tun.Device {
|
||||||
|
return &dummyTUN{
|
||||||
|
name: name,
|
||||||
|
packets: make(chan []byte, 100),
|
||||||
|
events: make(chan tun.Event, 10),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// A dummyTUN is a tun.Device which is used in unit tests.
|
||||||
|
type dummyTUN struct {
|
||||||
|
name string
|
||||||
|
mtu int
|
||||||
|
packets chan []byte
|
||||||
|
events chan tun.Event
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dummyTUN) Events() chan tun.Event { return d.events }
|
||||||
|
func (*dummyTUN) File() *os.File { return nil }
|
||||||
|
func (*dummyTUN) Flush() error { return nil }
|
||||||
|
func (d *dummyTUN) MTU() (int, error) { return d.mtu, nil }
|
||||||
|
func (d *dummyTUN) Name() (string, error) { return d.name, nil }
|
||||||
|
|
||||||
|
func (d *dummyTUN) Close() error {
|
||||||
|
close(d.events)
|
||||||
|
close(d.packets)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dummyTUN) Read(b []byte, offset int) (int, error) {
|
||||||
|
buf, ok := <-d.packets
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("device closed")
|
||||||
|
}
|
||||||
|
copy(b[offset:], buf)
|
||||||
|
return len(buf), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *dummyTUN) Write(b []byte, offset int) (int, error) {
|
||||||
|
d.packets <- b[offset:]
|
||||||
|
return len(b), nil
|
||||||
|
}
|
||||||
@@ -113,6 +113,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
|||||||
var peer *Peer
|
var peer *Peer
|
||||||
|
|
||||||
dummy := false
|
dummy := false
|
||||||
|
createdNewPeer := false
|
||||||
deviceConfig := true
|
deviceConfig := true
|
||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
@@ -237,13 +238,33 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
|||||||
peer = device.LookupPeer(publicKey)
|
peer = device.LookupPeer(publicKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer == nil {
|
createdNewPeer = peer == nil
|
||||||
|
if createdNewPeer {
|
||||||
peer, err = device.NewPeer(publicKey)
|
peer, err = device.NewPeer(publicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to create new peer:", err)
|
logError.Println("Failed to create new peer:", err)
|
||||||
return &IPCError{ipc.IpcErrorInvalid}
|
return &IPCError{ipc.IpcErrorInvalid}
|
||||||
}
|
}
|
||||||
logDebug.Println(peer, "- UAPI: Created")
|
if peer == nil {
|
||||||
|
dummy = true
|
||||||
|
peer = &Peer{}
|
||||||
|
} else {
|
||||||
|
logDebug.Println(peer, "- UAPI: Created")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case "update_only":
|
||||||
|
|
||||||
|
// allow disabling of creation
|
||||||
|
|
||||||
|
if value != "true" {
|
||||||
|
logError.Println("Failed to set update only, invalid value:", value)
|
||||||
|
return &IPCError{ipc.IpcErrorInvalid}
|
||||||
|
}
|
||||||
|
if createdNewPeer && !dummy {
|
||||||
|
device.RemovePeer(peer.handshake.remoteStatic)
|
||||||
|
peer = &Peer{}
|
||||||
|
dummy = true
|
||||||
}
|
}
|
||||||
|
|
||||||
case "remove":
|
case "remove":
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
const WireGuardGoVersion = "0.0.20190517"
|
const WireGuardGoVersion = "0.0.20200121"
|
||||||
|
|||||||
@@ -1,15 +0,0 @@
|
|||||||
// +build !android
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package main
|
|
||||||
|
|
||||||
const DoNotUseThisProgramOnLinux = UseTheKernelModuleInstead
|
|
||||||
|
|
||||||
// --------------------------------------------------------
|
|
||||||
// Do not use this on Linux. Instead use the kernel module.
|
|
||||||
// See wireguard.com/install for more information.
|
|
||||||
// --------------------------------------------------------
|
|
||||||
11
go.mod
11
go.mod
@@ -3,11 +3,8 @@ module golang.zx2c4.com/wireguard
|
|||||||
go 1.12
|
go 1.12
|
||||||
|
|
||||||
require (
|
require (
|
||||||
github.com/Microsoft/go-winio v0.4.12
|
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc
|
||||||
github.com/pkg/errors v0.8.1 // indirect
|
golang.org/x/net v0.0.0-20191003171128-d98b1b443823
|
||||||
golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734
|
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c
|
||||||
golang.org/x/net v0.0.0-20190502183928-7f726cade0ab
|
golang.org/x/text v0.3.2
|
||||||
golang.org/x/sys v0.0.0-20190502175342-a43fa875dd82
|
|
||||||
)
|
)
|
||||||
|
|
||||||
replace github.com/Microsoft/go-winio => golang.zx2c4.com/wireguard/windows v0.0.0-20190429060359-b01600290cd4
|
|
||||||
|
|||||||
19
go.sum
19
go.sum
@@ -1,15 +1,14 @@
|
|||||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
|
||||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
|
||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734 h1:p/H982KKEjUnLJkM3tt/LemDnOc1GiZL5FCVlORJ5zo=
|
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4=
|
||||||
golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
golang.org/x/net v0.0.0-20190502183928-7f726cade0ab h1:9RfW3ktsOZxgo9YNbBAjq1FWzc/igwEcUzZz8IXgSbk=
|
golang.org/x/net v0.0.0-20191003171128-d98b1b443823 h1:Ypyv6BNJh07T1pUSrehkLemqPKXhus2MkfktJ91kRh4=
|
||||||
golang.org/x/net v0.0.0-20190502183928-7f726cade0ab/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20190502175342-a43fa875dd82 h1:vsphBvatvfbhlb4PO1BYSr9dzugGxJ/SQHoNufZJq1w=
|
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c h1:6Zx7DRlKXf79yfxuQ/7GqV3w2y7aDsk6bGg0MzF5RVU=
|
||||||
golang.org/x/sys v0.0.0-20190502175342-a43fa875dd82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.zx2c4.com/wireguard/windows v0.0.0-20190429060359-b01600290cd4 h1:wueYNew2pMLl/LcKqX4PAzc+zV4suK9+DJaZ8yIEHkM=
|
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||||
golang.zx2c4.com/wireguard/windows v0.0.0-20190429060359-b01600290cd4/go.mod h1:Y+FYqVFaQO6a+1uigm0N0GiuaZrLEaBxEiJ8tfH9sMQ=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
|
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||||
|
|||||||
@@ -8,10 +8,12 @@ package ipc
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
"github.com/Microsoft/go-winio"
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/ipc/winpipe"
|
||||||
)
|
)
|
||||||
|
|
||||||
//TODO: replace these with actual standard windows error numbers from the win package
|
// TODO: replace these with actual standard windows error numbers from the win package
|
||||||
const (
|
const (
|
||||||
IpcErrorIO = -int64(5)
|
IpcErrorIO = -int64(5)
|
||||||
IpcErrorProtocol = -int64(71)
|
IpcErrorProtocol = -int64(71)
|
||||||
@@ -47,22 +49,22 @@ func (l *UAPIListener) Addr() net.Addr {
|
|||||||
return l.listener.Addr()
|
return l.listener.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
func GetSystemSecurityDescriptor() string {
|
var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR
|
||||||
//
|
|
||||||
// SDDL encoded.
|
func init() {
|
||||||
//
|
var err error
|
||||||
// (system = SECURITY_NT_AUTHORITY | SECURITY_LOCAL_SYSTEM_RID)
|
/* SDDL_DEVOBJ_SYS_ALL from the WDK */
|
||||||
// owner: system
|
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
|
||||||
// grant: GENERIC_ALL to system
|
if err != nil {
|
||||||
//
|
panic(err)
|
||||||
return "O:SYD:(A;;GA;;;SY)"
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func UAPIListen(name string) (net.Listener, error) {
|
func UAPIListen(name string) (net.Listener, error) {
|
||||||
config := winio.PipeConfig{
|
config := winpipe.PipeConfig{
|
||||||
SecurityDescriptor: GetSystemSecurityDescriptor(),
|
SecurityDescriptor: UAPISecurityDescriptor,
|
||||||
}
|
}
|
||||||
listener, err := winio.ListenPipe("\\\\.\\pipe\\WireGuard\\"+name, &config)
|
listener, err := winpipe.ListenPipe(`\\.\pipe\ProtectedPrefix\Administrators\WireGuard\`+name, &config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
323
ipc/winpipe/file.go
Normal file
323
ipc/winpipe/file.go
Normal file
@@ -0,0 +1,323 @@
|
|||||||
|
// +build windows
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2005 Microsoft
|
||||||
|
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
package winpipe
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx
|
||||||
|
//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort
|
||||||
|
//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
|
||||||
|
//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
|
||||||
|
//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
|
||||||
|
|
||||||
|
type atomicBool int32
|
||||||
|
|
||||||
|
func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
|
||||||
|
func (b *atomicBool) setFalse() { atomic.StoreInt32((*int32)(b), 0) }
|
||||||
|
func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
|
||||||
|
func (b *atomicBool) swap(new bool) bool {
|
||||||
|
var newInt int32
|
||||||
|
if new {
|
||||||
|
newInt = 1
|
||||||
|
}
|
||||||
|
return atomic.SwapInt32((*int32)(b), newInt) == 1
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS = 1
|
||||||
|
cFILE_SKIP_SET_EVENT_ON_HANDLE = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrFileClosed = errors.New("file has already been closed")
|
||||||
|
ErrTimeout = &timeoutError{}
|
||||||
|
)
|
||||||
|
|
||||||
|
type timeoutError struct{}
|
||||||
|
|
||||||
|
func (e *timeoutError) Error() string { return "i/o timeout" }
|
||||||
|
func (e *timeoutError) Timeout() bool { return true }
|
||||||
|
func (e *timeoutError) Temporary() bool { return true }
|
||||||
|
|
||||||
|
type timeoutChan chan struct{}
|
||||||
|
|
||||||
|
var ioInitOnce sync.Once
|
||||||
|
var ioCompletionPort windows.Handle
|
||||||
|
|
||||||
|
// ioResult contains the result of an asynchronous IO operation
|
||||||
|
type ioResult struct {
|
||||||
|
bytes uint32
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// ioOperation represents an outstanding asynchronous Win32 IO
|
||||||
|
type ioOperation struct {
|
||||||
|
o windows.Overlapped
|
||||||
|
ch chan ioResult
|
||||||
|
}
|
||||||
|
|
||||||
|
func initIo() {
|
||||||
|
h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
ioCompletionPort = h
|
||||||
|
go ioCompletionProcessor(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
|
||||||
|
// It takes ownership of this handle and will close it if it is garbage collected.
|
||||||
|
type win32File struct {
|
||||||
|
handle windows.Handle
|
||||||
|
wg sync.WaitGroup
|
||||||
|
wgLock sync.RWMutex
|
||||||
|
closing atomicBool
|
||||||
|
socket bool
|
||||||
|
readDeadline deadlineHandler
|
||||||
|
writeDeadline deadlineHandler
|
||||||
|
}
|
||||||
|
|
||||||
|
type deadlineHandler struct {
|
||||||
|
setLock sync.Mutex
|
||||||
|
channel timeoutChan
|
||||||
|
channelLock sync.RWMutex
|
||||||
|
timer *time.Timer
|
||||||
|
timedout atomicBool
|
||||||
|
}
|
||||||
|
|
||||||
|
// makeWin32File makes a new win32File from an existing file handle
|
||||||
|
func makeWin32File(h windows.Handle) (*win32File, error) {
|
||||||
|
f := &win32File{handle: h}
|
||||||
|
ioInitOnce.Do(initIo)
|
||||||
|
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = setFileCompletionNotificationModes(h, cFILE_SKIP_COMPLETION_PORT_ON_SUCCESS|cFILE_SKIP_SET_EVENT_ON_HANDLE)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f.readDeadline.channel = make(timeoutChan)
|
||||||
|
f.writeDeadline.channel = make(timeoutChan)
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func MakeOpenFile(h windows.Handle) (io.ReadWriteCloser, error) {
|
||||||
|
return makeWin32File(h)
|
||||||
|
}
|
||||||
|
|
||||||
|
// closeHandle closes the resources associated with a Win32 handle
|
||||||
|
func (f *win32File) closeHandle() {
|
||||||
|
f.wgLock.Lock()
|
||||||
|
// Atomically set that we are closing, releasing the resources only once.
|
||||||
|
if !f.closing.swap(true) {
|
||||||
|
f.wgLock.Unlock()
|
||||||
|
// cancel all IO and wait for it to complete
|
||||||
|
cancelIoEx(f.handle, nil)
|
||||||
|
f.wg.Wait()
|
||||||
|
// at this point, no new IO can start
|
||||||
|
windows.Close(f.handle)
|
||||||
|
f.handle = 0
|
||||||
|
} else {
|
||||||
|
f.wgLock.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close closes a win32File.
|
||||||
|
func (f *win32File) Close() error {
|
||||||
|
f.closeHandle()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// prepareIo prepares for a new IO operation.
|
||||||
|
// The caller must call f.wg.Done() when the IO is finished, prior to Close() returning.
|
||||||
|
func (f *win32File) prepareIo() (*ioOperation, error) {
|
||||||
|
f.wgLock.RLock()
|
||||||
|
if f.closing.isSet() {
|
||||||
|
f.wgLock.RUnlock()
|
||||||
|
return nil, ErrFileClosed
|
||||||
|
}
|
||||||
|
f.wg.Add(1)
|
||||||
|
f.wgLock.RUnlock()
|
||||||
|
c := &ioOperation{}
|
||||||
|
c.ch = make(chan ioResult)
|
||||||
|
return c, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ioCompletionProcessor processes completed async IOs forever
|
||||||
|
func ioCompletionProcessor(h windows.Handle) {
|
||||||
|
for {
|
||||||
|
var bytes uint32
|
||||||
|
var key uintptr
|
||||||
|
var op *ioOperation
|
||||||
|
err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE)
|
||||||
|
if op == nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
op.ch <- ioResult{bytes, err}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
|
||||||
|
// the operation has actually completed.
|
||||||
|
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
|
||||||
|
if err != windows.ERROR_IO_PENDING {
|
||||||
|
return int(bytes), err
|
||||||
|
}
|
||||||
|
|
||||||
|
if f.closing.isSet() {
|
||||||
|
cancelIoEx(f.handle, &c.o)
|
||||||
|
}
|
||||||
|
|
||||||
|
var timeout timeoutChan
|
||||||
|
if d != nil {
|
||||||
|
d.channelLock.Lock()
|
||||||
|
timeout = d.channel
|
||||||
|
d.channelLock.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
var r ioResult
|
||||||
|
select {
|
||||||
|
case r = <-c.ch:
|
||||||
|
err = r.err
|
||||||
|
if err == windows.ERROR_OPERATION_ABORTED {
|
||||||
|
if f.closing.isSet() {
|
||||||
|
err = ErrFileClosed
|
||||||
|
}
|
||||||
|
} else if err != nil && f.socket {
|
||||||
|
// err is from Win32. Query the overlapped structure to get the winsock error.
|
||||||
|
var bytes, flags uint32
|
||||||
|
err = wsaGetOverlappedResult(f.handle, &c.o, &bytes, false, &flags)
|
||||||
|
}
|
||||||
|
case <-timeout:
|
||||||
|
cancelIoEx(f.handle, &c.o)
|
||||||
|
r = <-c.ch
|
||||||
|
err = r.err
|
||||||
|
if err == windows.ERROR_OPERATION_ABORTED {
|
||||||
|
err = ErrTimeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// runtime.KeepAlive is needed, as c is passed via native
|
||||||
|
// code to ioCompletionProcessor, c must remain alive
|
||||||
|
// until the channel read is complete.
|
||||||
|
runtime.KeepAlive(c)
|
||||||
|
return int(r.bytes), err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads from a file handle.
|
||||||
|
func (f *win32File) Read(b []byte) (int, error) {
|
||||||
|
c, err := f.prepareIo()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer f.wg.Done()
|
||||||
|
|
||||||
|
if f.readDeadline.timedout.isSet() {
|
||||||
|
return 0, ErrTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
var bytes uint32
|
||||||
|
err = windows.ReadFile(f.handle, b, &bytes, &c.o)
|
||||||
|
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
|
||||||
|
runtime.KeepAlive(b)
|
||||||
|
|
||||||
|
// Handle EOF conditions.
|
||||||
|
if err == nil && n == 0 && len(b) != 0 {
|
||||||
|
return 0, io.EOF
|
||||||
|
} else if err == windows.ERROR_BROKEN_PIPE {
|
||||||
|
return 0, io.EOF
|
||||||
|
} else {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes to a file handle.
|
||||||
|
func (f *win32File) Write(b []byte) (int, error) {
|
||||||
|
c, err := f.prepareIo()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
defer f.wg.Done()
|
||||||
|
|
||||||
|
if f.writeDeadline.timedout.isSet() {
|
||||||
|
return 0, ErrTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
var bytes uint32
|
||||||
|
err = windows.WriteFile(f.handle, b, &bytes, &c.o)
|
||||||
|
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
|
||||||
|
runtime.KeepAlive(b)
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *win32File) SetReadDeadline(deadline time.Time) error {
|
||||||
|
return f.readDeadline.set(deadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *win32File) SetWriteDeadline(deadline time.Time) error {
|
||||||
|
return f.writeDeadline.set(deadline)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *win32File) Flush() error {
|
||||||
|
return windows.FlushFileBuffers(f.handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *win32File) Fd() uintptr {
|
||||||
|
return uintptr(f.handle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *deadlineHandler) set(deadline time.Time) error {
|
||||||
|
d.setLock.Lock()
|
||||||
|
defer d.setLock.Unlock()
|
||||||
|
|
||||||
|
if d.timer != nil {
|
||||||
|
if !d.timer.Stop() {
|
||||||
|
<-d.channel
|
||||||
|
}
|
||||||
|
d.timer = nil
|
||||||
|
}
|
||||||
|
d.timedout.setFalse()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-d.channel:
|
||||||
|
d.channelLock.Lock()
|
||||||
|
d.channel = make(chan struct{})
|
||||||
|
d.channelLock.Unlock()
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
|
if deadline.IsZero() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
timeoutIO := func() {
|
||||||
|
d.timedout.setTrue()
|
||||||
|
close(d.channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
duration := deadline.Sub(now)
|
||||||
|
if deadline.After(now) {
|
||||||
|
// Deadline is in the future, set a timer to wait
|
||||||
|
d.timer = time.AfterFunc(duration, timeoutIO)
|
||||||
|
} else {
|
||||||
|
// Deadline is in the past. Cancel all pending IO now.
|
||||||
|
timeoutIO()
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
9
ipc/winpipe/mksyscall.go
Normal file
9
ipc/winpipe/mksyscall.go
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2005 Microsoft
|
||||||
|
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package winpipe
|
||||||
|
|
||||||
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go pipe.go file.go
|
||||||
509
ipc/winpipe/pipe.go
Normal file
509
ipc/winpipe/pipe.go
Normal file
@@ -0,0 +1,509 @@
|
|||||||
|
// +build windows
|
||||||
|
|
||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2005 Microsoft
|
||||||
|
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package winpipe
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"runtime"
|
||||||
|
"time"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
//sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe
|
||||||
|
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW
|
||||||
|
//sys createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW
|
||||||
|
//sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
|
||||||
|
//sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
|
||||||
|
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
|
||||||
|
//sys ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
|
||||||
|
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
|
||||||
|
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
|
||||||
|
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
|
||||||
|
|
||||||
|
type ioStatusBlock struct {
|
||||||
|
Status, Information uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
type objectAttributes struct {
|
||||||
|
Length uintptr
|
||||||
|
RootDirectory uintptr
|
||||||
|
ObjectName *unicodeString
|
||||||
|
Attributes uintptr
|
||||||
|
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
|
||||||
|
SecurityQoS uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
type unicodeString struct {
|
||||||
|
Length uint16
|
||||||
|
MaximumLength uint16
|
||||||
|
Buffer uintptr
|
||||||
|
}
|
||||||
|
|
||||||
|
type ntstatus int32
|
||||||
|
|
||||||
|
func (status ntstatus) Err() error {
|
||||||
|
if status >= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return rtlNtStatusToDosError(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
cSECURITY_SQOS_PRESENT = 0x100000
|
||||||
|
cSECURITY_ANONYMOUS = 0
|
||||||
|
|
||||||
|
cPIPE_TYPE_MESSAGE = 4
|
||||||
|
|
||||||
|
cPIPE_READMODE_MESSAGE = 2
|
||||||
|
|
||||||
|
cFILE_OPEN = 1
|
||||||
|
cFILE_CREATE = 2
|
||||||
|
|
||||||
|
cFILE_PIPE_MESSAGE_TYPE = 1
|
||||||
|
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrPipeListenerClosed is returned for pipe operations on listeners that have been closed.
|
||||||
|
// This error should match net.errClosing since docker takes a dependency on its text.
|
||||||
|
ErrPipeListenerClosed = errors.New("use of closed network connection")
|
||||||
|
|
||||||
|
errPipeWriteClosed = errors.New("pipe has been closed for write")
|
||||||
|
)
|
||||||
|
|
||||||
|
type win32Pipe struct {
|
||||||
|
*win32File
|
||||||
|
path string
|
||||||
|
}
|
||||||
|
|
||||||
|
type win32MessageBytePipe struct {
|
||||||
|
win32Pipe
|
||||||
|
writeClosed bool
|
||||||
|
readEOF bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type pipeAddress string
|
||||||
|
|
||||||
|
func (f *win32Pipe) LocalAddr() net.Addr {
|
||||||
|
return pipeAddress(f.path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *win32Pipe) RemoteAddr() net.Addr {
|
||||||
|
return pipeAddress(f.path)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *win32Pipe) SetDeadline(t time.Time) error {
|
||||||
|
f.SetReadDeadline(t)
|
||||||
|
f.SetWriteDeadline(t)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseWrite closes the write side of a message pipe in byte mode.
|
||||||
|
func (f *win32MessageBytePipe) CloseWrite() error {
|
||||||
|
if f.writeClosed {
|
||||||
|
return errPipeWriteClosed
|
||||||
|
}
|
||||||
|
err := f.win32File.Flush()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = f.win32File.Write(nil)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
f.writeClosed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write writes bytes to a message pipe in byte mode. Zero-byte writes are ignored, since
|
||||||
|
// they are used to implement CloseWrite().
|
||||||
|
func (f *win32MessageBytePipe) Write(b []byte) (int, error) {
|
||||||
|
if f.writeClosed {
|
||||||
|
return 0, errPipeWriteClosed
|
||||||
|
}
|
||||||
|
if len(b) == 0 {
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
return f.win32File.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read reads bytes from a message pipe in byte mode. A read of a zero-byte message on a message
|
||||||
|
// mode pipe will return io.EOF, as will all subsequent reads.
|
||||||
|
func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
|
||||||
|
if f.readEOF {
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
n, err := f.win32File.Read(b)
|
||||||
|
if err == io.EOF {
|
||||||
|
// If this was the result of a zero-byte read, then
|
||||||
|
// it is possible that the read was due to a zero-size
|
||||||
|
// message. Since we are simulating CloseWrite with a
|
||||||
|
// zero-byte message, ensure that all future Read() calls
|
||||||
|
// also return EOF.
|
||||||
|
f.readEOF = true
|
||||||
|
} else if err == windows.ERROR_MORE_DATA {
|
||||||
|
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
|
||||||
|
// and the message still has more bytes. Treat this as a success, since
|
||||||
|
// this package presents all named pipes as byte streams.
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s pipeAddress) Network() string {
|
||||||
|
return "pipe"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s pipeAddress) String() string {
|
||||||
|
return string(s)
|
||||||
|
}
|
||||||
|
|
||||||
|
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
|
||||||
|
func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return windows.Handle(0), ctx.Err()
|
||||||
|
default:
|
||||||
|
h, err := createFile(*path, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
|
||||||
|
if err == nil {
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
if err != windows.ERROR_PIPE_BUSY {
|
||||||
|
return h, &os.PathError{Err: err, Op: "open", Path: *path}
|
||||||
|
}
|
||||||
|
// Wait 10 msec and try again. This is a rather simplistic
|
||||||
|
// view, as we always try each 10 milliseconds.
|
||||||
|
time.Sleep(time.Millisecond * 10)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialPipe connects to a named pipe by path, timing out if the connection
|
||||||
|
// takes longer than the specified duration. If timeout is nil, then we use
|
||||||
|
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
|
||||||
|
func DialPipe(path string, timeout *time.Duration, expectedOwner *windows.SID) (net.Conn, error) {
|
||||||
|
var absTimeout time.Time
|
||||||
|
if timeout != nil {
|
||||||
|
absTimeout = time.Now().Add(*timeout)
|
||||||
|
} else {
|
||||||
|
absTimeout = time.Now().Add(time.Second * 2)
|
||||||
|
}
|
||||||
|
ctx, _ := context.WithDeadline(context.Background(), absTimeout)
|
||||||
|
conn, err := DialPipeContext(ctx, path, expectedOwner)
|
||||||
|
if err == context.DeadlineExceeded {
|
||||||
|
return nil, ErrTimeout
|
||||||
|
}
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
|
||||||
|
// cancellation or timeout.
|
||||||
|
func DialPipeContext(ctx context.Context, path string, expectedOwner *windows.SID) (net.Conn, error) {
|
||||||
|
var err error
|
||||||
|
var h windows.Handle
|
||||||
|
h, err = tryDialPipe(ctx, &path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if expectedOwner != nil {
|
||||||
|
sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
realOwner, _, err := sd.Owner()
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if !realOwner.Equals(expectedOwner) {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, windows.ERROR_ACCESS_DENIED
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var flags uint32
|
||||||
|
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
f, err := makeWin32File(h)
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the pipe is in message mode, return a message byte pipe, which
|
||||||
|
// supports CloseWrite().
|
||||||
|
if flags&cPIPE_TYPE_MESSAGE != 0 {
|
||||||
|
return &win32MessageBytePipe{
|
||||||
|
win32Pipe: win32Pipe{win32File: f, path: path},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return &win32Pipe{win32File: f, path: path}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type acceptResponse struct {
|
||||||
|
f *win32File
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
type win32PipeListener struct {
|
||||||
|
firstHandle windows.Handle
|
||||||
|
path string
|
||||||
|
config PipeConfig
|
||||||
|
acceptCh chan (chan acceptResponse)
|
||||||
|
closeCh chan int
|
||||||
|
doneCh chan int
|
||||||
|
}
|
||||||
|
|
||||||
|
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *PipeConfig, first bool) (windows.Handle, error) {
|
||||||
|
path16, err := windows.UTF16FromString(path)
|
||||||
|
if err != nil {
|
||||||
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
var oa objectAttributes
|
||||||
|
oa.Length = unsafe.Sizeof(oa)
|
||||||
|
|
||||||
|
var ntPath unicodeString
|
||||||
|
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
|
||||||
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
|
}
|
||||||
|
defer windows.LocalFree(windows.Handle(ntPath.Buffer))
|
||||||
|
oa.ObjectName = &ntPath
|
||||||
|
|
||||||
|
// The security descriptor is only needed for the first pipe.
|
||||||
|
if first {
|
||||||
|
if sd != nil {
|
||||||
|
oa.SecurityDescriptor = sd
|
||||||
|
} else {
|
||||||
|
// Construct the default named pipe security descriptor.
|
||||||
|
var dacl uintptr
|
||||||
|
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
|
||||||
|
return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
|
||||||
|
}
|
||||||
|
defer windows.LocalFree(windows.Handle(dacl))
|
||||||
|
sd, err := windows.NewSecurityDescriptor()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("creating new security descriptor: %s", err)
|
||||||
|
}
|
||||||
|
if err = sd.SetDACL((*windows.ACL)(unsafe.Pointer(dacl)), true, false); err != nil {
|
||||||
|
return 0, fmt.Errorf("assigning dacl: %s", err)
|
||||||
|
}
|
||||||
|
sd, err = sd.ToSelfRelative()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("converting to self-relative: %s", err)
|
||||||
|
}
|
||||||
|
oa.SecurityDescriptor = sd
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
typ := uint32(cFILE_PIPE_REJECT_REMOTE_CLIENTS)
|
||||||
|
if c.MessageMode {
|
||||||
|
typ |= cFILE_PIPE_MESSAGE_TYPE
|
||||||
|
}
|
||||||
|
|
||||||
|
disposition := uint32(cFILE_OPEN)
|
||||||
|
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
|
||||||
|
if first {
|
||||||
|
disposition = cFILE_CREATE
|
||||||
|
// By not asking for read or write access, the named pipe file system
|
||||||
|
// will put this pipe into an initially disconnected state, blocking
|
||||||
|
// client connections until the next call with first == false.
|
||||||
|
access = windows.SYNCHRONIZE
|
||||||
|
}
|
||||||
|
|
||||||
|
timeout := int64(-50 * 10000) // 50ms
|
||||||
|
|
||||||
|
var (
|
||||||
|
h windows.Handle
|
||||||
|
iosb ioStatusBlock
|
||||||
|
)
|
||||||
|
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
|
||||||
|
if err != nil {
|
||||||
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
|
}
|
||||||
|
|
||||||
|
runtime.KeepAlive(ntPath)
|
||||||
|
return h, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
|
||||||
|
h, err := makeServerPipeHandle(l.path, nil, &l.config, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
f, err := makeWin32File(h)
|
||||||
|
if err != nil {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *win32PipeListener) makeConnectedServerPipe() (*win32File, error) {
|
||||||
|
p, err := l.makeServerPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for the client to connect.
|
||||||
|
ch := make(chan error)
|
||||||
|
go func(p *win32File) {
|
||||||
|
ch <- connectPipe(p)
|
||||||
|
}(p)
|
||||||
|
|
||||||
|
select {
|
||||||
|
case err = <-ch:
|
||||||
|
if err != nil {
|
||||||
|
p.Close()
|
||||||
|
p = nil
|
||||||
|
}
|
||||||
|
case <-l.closeCh:
|
||||||
|
// Abort the connect request by closing the handle.
|
||||||
|
p.Close()
|
||||||
|
p = nil
|
||||||
|
err = <-ch
|
||||||
|
if err == nil || err == ErrFileClosed {
|
||||||
|
err = ErrPipeListenerClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return p, err
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *win32PipeListener) listenerRoutine() {
|
||||||
|
closed := false
|
||||||
|
for !closed {
|
||||||
|
select {
|
||||||
|
case <-l.closeCh:
|
||||||
|
closed = true
|
||||||
|
case responseCh := <-l.acceptCh:
|
||||||
|
var (
|
||||||
|
p *win32File
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
p, err = l.makeConnectedServerPipe()
|
||||||
|
// If the connection was immediately closed by the client, try
|
||||||
|
// again.
|
||||||
|
if err != windows.ERROR_NO_DATA {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
responseCh <- acceptResponse{p, err}
|
||||||
|
closed = err == ErrPipeListenerClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
windows.Close(l.firstHandle)
|
||||||
|
l.firstHandle = 0
|
||||||
|
// Notify Close() and Accept() callers that the handle has been closed.
|
||||||
|
close(l.doneCh)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PipeConfig contain configuration for the pipe listener.
|
||||||
|
type PipeConfig struct {
|
||||||
|
// SecurityDescriptor contains a Windows security descriptor.
|
||||||
|
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
|
||||||
|
|
||||||
|
// MessageMode determines whether the pipe is in byte or message mode. In either
|
||||||
|
// case the pipe is read in byte mode by default. The only practical difference in
|
||||||
|
// this implementation is that CloseWrite() is only supported for message mode pipes;
|
||||||
|
// CloseWrite() is implemented as a zero-byte write, but zero-byte writes are only
|
||||||
|
// transferred to the reader (and returned as io.EOF in this implementation)
|
||||||
|
// when the pipe is in message mode.
|
||||||
|
MessageMode bool
|
||||||
|
|
||||||
|
// InputBufferSize specifies the size the input buffer, in bytes.
|
||||||
|
InputBufferSize int32
|
||||||
|
|
||||||
|
// OutputBufferSize specifies the size the input buffer, in bytes.
|
||||||
|
OutputBufferSize int32
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
|
||||||
|
// The pipe must not already exist.
|
||||||
|
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
|
||||||
|
if c == nil {
|
||||||
|
c = &PipeConfig{}
|
||||||
|
}
|
||||||
|
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
l := &win32PipeListener{
|
||||||
|
firstHandle: h,
|
||||||
|
path: path,
|
||||||
|
config: *c,
|
||||||
|
acceptCh: make(chan (chan acceptResponse)),
|
||||||
|
closeCh: make(chan int),
|
||||||
|
doneCh: make(chan int),
|
||||||
|
}
|
||||||
|
go l.listenerRoutine()
|
||||||
|
return l, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func connectPipe(p *win32File) error {
|
||||||
|
c, err := p.prepareIo()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer p.wg.Done()
|
||||||
|
|
||||||
|
err = connectNamedPipe(p.handle, &c.o)
|
||||||
|
_, err = p.asyncIo(c, nil, 0, err)
|
||||||
|
if err != nil && err != windows.ERROR_PIPE_CONNECTED {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *win32PipeListener) Accept() (net.Conn, error) {
|
||||||
|
ch := make(chan acceptResponse)
|
||||||
|
select {
|
||||||
|
case l.acceptCh <- ch:
|
||||||
|
response := <-ch
|
||||||
|
err := response.err
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if l.config.MessageMode {
|
||||||
|
return &win32MessageBytePipe{
|
||||||
|
win32Pipe: win32Pipe{win32File: response.f, path: l.path},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
return &win32Pipe{win32File: response.f, path: l.path}, nil
|
||||||
|
case <-l.doneCh:
|
||||||
|
return nil, ErrPipeListenerClosed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *win32PipeListener) Close() error {
|
||||||
|
select {
|
||||||
|
case l.closeCh <- 1:
|
||||||
|
<-l.doneCh
|
||||||
|
case <-l.doneCh:
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *win32PipeListener) Addr() net.Addr {
|
||||||
|
return pipeAddress(l.path)
|
||||||
|
}
|
||||||
238
ipc/winpipe/zsyscall_windows.go
Normal file
238
ipc/winpipe/zsyscall_windows.go
Normal file
@@ -0,0 +1,238 @@
|
|||||||
|
// Code generated by 'go generate'; DO NOT EDIT.
|
||||||
|
|
||||||
|
package winpipe
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ unsafe.Pointer
|
||||||
|
|
||||||
|
// Do the interface allocations only once for common
|
||||||
|
// Errno values.
|
||||||
|
const (
|
||||||
|
errnoERROR_IO_PENDING = 997
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||||
|
)
|
||||||
|
|
||||||
|
// errnoErr returns common boxed Errno values, to prevent
|
||||||
|
// allocations at runtime.
|
||||||
|
func errnoErr(e syscall.Errno) error {
|
||||||
|
switch e {
|
||||||
|
case 0:
|
||||||
|
return nil
|
||||||
|
case errnoERROR_IO_PENDING:
|
||||||
|
return errERROR_IO_PENDING
|
||||||
|
}
|
||||||
|
// TODO: add more here, after collecting data on the common
|
||||||
|
// error values see on Windows. (perhaps when running
|
||||||
|
// all.bat?)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||||
|
modntdll = windows.NewLazySystemDLL("ntdll.dll")
|
||||||
|
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||||
|
|
||||||
|
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
|
||||||
|
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
|
||||||
|
procCreateFileW = modkernel32.NewProc("CreateFileW")
|
||||||
|
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
|
||||||
|
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
|
||||||
|
procLocalAlloc = modkernel32.NewProc("LocalAlloc")
|
||||||
|
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
|
||||||
|
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
|
||||||
|
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
|
||||||
|
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
|
||||||
|
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
|
||||||
|
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
|
||||||
|
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
|
||||||
|
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
|
||||||
|
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
|
||||||
|
)
|
||||||
|
|
||||||
|
func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) {
|
||||||
|
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
|
||||||
|
if r1 == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
|
||||||
|
var _p0 *uint16
|
||||||
|
_p0, err = syscall.UTF16PtrFromString(name)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
|
||||||
|
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
|
||||||
|
handle = windows.Handle(r0)
|
||||||
|
if handle == windows.InvalidHandle {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
|
||||||
|
var _p0 *uint16
|
||||||
|
_p0, err = syscall.UTF16PtrFromString(name)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _createFile(name *uint16, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
|
||||||
|
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
|
||||||
|
handle = windows.Handle(r0)
|
||||||
|
if handle == windows.InvalidHandle {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
|
||||||
|
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
|
||||||
|
if r1 == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
|
||||||
|
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
|
||||||
|
if r1 == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
|
||||||
|
r0, _, _ := syscall.Syscall(procLocalAlloc.Addr(), 2, uintptr(uFlags), uintptr(length), 0)
|
||||||
|
ptr = uintptr(r0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
|
||||||
|
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
|
||||||
|
status = ntstatus(r0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func rtlNtStatusToDosError(status ntstatus) (winerr error) {
|
||||||
|
r0, _, _ := syscall.Syscall(procRtlNtStatusToDosErrorNoTeb.Addr(), 1, uintptr(status), 0, 0)
|
||||||
|
if r0 != 0 {
|
||||||
|
winerr = syscall.Errno(r0)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) {
|
||||||
|
r0, _, _ := syscall.Syscall6(procRtlDosPathNameToNtPathName_U.Addr(), 4, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(ntName)), uintptr(filePart), uintptr(reserved), 0, 0)
|
||||||
|
status = ntstatus(r0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) {
|
||||||
|
r0, _, _ := syscall.Syscall(procRtlDefaultNpAcl.Addr(), 1, uintptr(unsafe.Pointer(dacl)), 0, 0)
|
||||||
|
status = ntstatus(r0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) {
|
||||||
|
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
|
||||||
|
if r1 == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) {
|
||||||
|
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
|
||||||
|
newport = windows.Handle(r0)
|
||||||
|
if newport == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
|
||||||
|
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
|
||||||
|
if r1 == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) {
|
||||||
|
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
|
||||||
|
if r1 == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
|
||||||
|
var _p0 uint32
|
||||||
|
if wait {
|
||||||
|
_p0 = 1
|
||||||
|
} else {
|
||||||
|
_p0 = 0
|
||||||
|
}
|
||||||
|
r1, _, e1 := syscall.Syscall6(procWSAGetOverlappedResult.Addr(), 5, uintptr(h), uintptr(unsafe.Pointer(o)), uintptr(unsafe.Pointer(bytes)), uintptr(_p0), uintptr(unsafe.Pointer(flags)), 0)
|
||||||
|
if r1 == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
18
main.go
18
main.go
@@ -40,31 +40,19 @@ func warning() {
|
|||||||
if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
|
if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
shouldQuit := os.Getenv("WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD") != "1"
|
|
||||||
|
|
||||||
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
|
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
|
||||||
fmt.Fprintln(os.Stderr, "W G")
|
fmt.Fprintln(os.Stderr, "W G")
|
||||||
fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G")
|
fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G")
|
||||||
fmt.Fprintln(os.Stderr, "W which is probably unnecessary and foolish. This G")
|
fmt.Fprintln(os.Stderr, "W which is probably unnecessary and misguided. This G")
|
||||||
fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G")
|
fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G")
|
||||||
fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G")
|
fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G")
|
||||||
fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G")
|
fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G")
|
||||||
fmt.Fprintln(os.Stderr, "W implementation. For more information on G")
|
fmt.Fprintln(os.Stderr, "W implementation. For more information on G")
|
||||||
fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G")
|
fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G")
|
||||||
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
|
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
|
||||||
if shouldQuit {
|
|
||||||
fmt.Fprintln(os.Stderr, "W G")
|
|
||||||
fmt.Fprintln(os.Stderr, "W If you still want to use this program, against G")
|
|
||||||
fmt.Fprintln(os.Stderr, "W the advice here, please first export this G")
|
|
||||||
fmt.Fprintln(os.Stderr, "W environment variable: G")
|
|
||||||
fmt.Fprintln(os.Stderr, "W WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD=1 G")
|
|
||||||
}
|
|
||||||
fmt.Fprintln(os.Stderr, "W G")
|
fmt.Fprintln(os.Stderr, "W G")
|
||||||
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
|
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
|
||||||
|
|
||||||
if shouldQuit {
|
|
||||||
os.Exit(1)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
@@ -75,8 +63,6 @@ func main() {
|
|||||||
|
|
||||||
warning()
|
warning()
|
||||||
|
|
||||||
// parse arguments
|
|
||||||
|
|
||||||
var foreground bool
|
var foreground bool
|
||||||
var interfaceName string
|
var interfaceName string
|
||||||
if len(os.Args) < 2 || len(os.Args) > 3 {
|
if len(os.Args) < 2 || len(os.Args) > 3 {
|
||||||
@@ -125,7 +111,7 @@ func main() {
|
|||||||
|
|
||||||
// open TUN device (or use supplied fd)
|
// open TUN device (or use supplied fd)
|
||||||
|
|
||||||
tun, err := func() (tun.TUNDevice, error) {
|
tun, err := func() (tun.Device, error) {
|
||||||
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
tunFdStr := os.Getenv(ENV_WG_TUN_FD)
|
||||||
if tunFdStr == "" {
|
if tunFdStr == "" {
|
||||||
return tun.CreateTUN(interfaceName, device.DefaultMTU)
|
return tun.CreateTUN(interfaceName, device.DefaultMTU)
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ func main() {
|
|||||||
logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
|
logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
|
||||||
logger.Debug.Println("Debug log enabled")
|
logger.Debug.Println("Debug log enabled")
|
||||||
|
|
||||||
tun, err := tun.CreateTUN(interfaceName)
|
tun, err := tun.CreateTUN(interfaceName, 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
realInterfaceName, err2 := tun.Name()
|
realInterfaceName, err2 := tun.Name()
|
||||||
if err2 == nil {
|
if err2 == nil {
|
||||||
|
|||||||
@@ -76,7 +76,7 @@ func (rate *Ratelimiter) Init() {
|
|||||||
|
|
||||||
for key, entry := range rate.tableIPv4 {
|
for key, entry := range rate.tableIPv4 {
|
||||||
entry.Lock()
|
entry.Lock()
|
||||||
if time.Now().Sub(entry.lastTime) > garbageCollectTime {
|
if time.Since(entry.lastTime) > garbageCollectTime {
|
||||||
delete(rate.tableIPv4, key)
|
delete(rate.tableIPv4, key)
|
||||||
}
|
}
|
||||||
entry.Unlock()
|
entry.Unlock()
|
||||||
@@ -84,7 +84,7 @@ func (rate *Ratelimiter) Init() {
|
|||||||
|
|
||||||
for key, entry := range rate.tableIPv6 {
|
for key, entry := range rate.tableIPv6 {
|
||||||
entry.Lock()
|
entry.Lock()
|
||||||
if time.Now().Sub(entry.lastTime) > garbageCollectTime {
|
if time.Since(entry.lastTime) > garbageCollectTime {
|
||||||
delete(rate.tableIPv6, key)
|
delete(rate.tableIPv6, key)
|
||||||
}
|
}
|
||||||
entry.Unlock()
|
entry.Unlock()
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ func TestRatelimiter(t *testing.T) {
|
|||||||
for i := 0; i < packetsBurstable; i++ {
|
for i := 0; i < packetsBurstable; i++ {
|
||||||
Add(RatelimiterResult{
|
Add(RatelimiterResult{
|
||||||
allowed: true,
|
allowed: true,
|
||||||
text: "inital burst",
|
text: "initial burst",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,13 @@ func (rw *RWCancel) ReadyRead() bool {
|
|||||||
fdset := fdSet{}
|
fdset := fdSet{}
|
||||||
fdset.set(rw.fd)
|
fdset.set(rw.fd)
|
||||||
fdset.set(closeFd)
|
fdset.set(closeFd)
|
||||||
err := unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
|
var err error
|
||||||
|
for {
|
||||||
|
err = unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
|
||||||
|
if err == nil || !RetryAfterError(err) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -75,7 +81,13 @@ func (rw *RWCancel) ReadyWrite() bool {
|
|||||||
fdset := fdSet{}
|
fdset := fdSet{}
|
||||||
fdset.set(rw.fd)
|
fdset.set(rw.fd)
|
||||||
fdset.set(closeFd)
|
fdset.set(closeFd)
|
||||||
err := unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
|
var err error
|
||||||
|
for {
|
||||||
|
err = unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
|
||||||
|
if err == nil || !RetryAfterError(err) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,11 +15,15 @@ import (
|
|||||||
*/
|
*/
|
||||||
func TestMonotonic(t *testing.T) {
|
func TestMonotonic(t *testing.T) {
|
||||||
old := Now()
|
old := Now()
|
||||||
for i := 0; i < 10000; i++ {
|
for i := 0; i < 50; i++ {
|
||||||
time.Sleep(time.Nanosecond)
|
|
||||||
next := Now()
|
next := Now()
|
||||||
|
if next.After(old) {
|
||||||
|
t.Error("Whitening insufficient")
|
||||||
|
}
|
||||||
|
time.Sleep(time.Duration(whitenerMask)/time.Nanosecond + 1)
|
||||||
|
next = Now()
|
||||||
if !next.After(old) {
|
if !next.After(old) {
|
||||||
t.Error("TAI64N, not monotonically increasing on nano-second scale")
|
t.Error("Not monotonically increasing on whitened nano-second scale")
|
||||||
}
|
}
|
||||||
old = next
|
old = next
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,93 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package tun
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"errors"
|
|
||||||
"os"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
|
||||||
)
|
|
||||||
|
|
||||||
/* Helpers for writing unit tests
|
|
||||||
*/
|
|
||||||
|
|
||||||
type DummyTUN struct {
|
|
||||||
name string
|
|
||||||
mtu int
|
|
||||||
packets chan []byte
|
|
||||||
events chan tun.TUNEvent
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *DummyTUN) File() *os.File {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *DummyTUN) Name() (string, error) {
|
|
||||||
return tun.name, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *DummyTUN) MTU() (int, error) {
|
|
||||||
return tun.mtu, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *DummyTUN) Write(d []byte, offset int) (int, error) {
|
|
||||||
tun.packets <- d[offset:]
|
|
||||||
return len(d), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *DummyTUN) Close() error {
|
|
||||||
close(tun.events)
|
|
||||||
close(tun.packets)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *DummyTUN) Events() chan tun.TUNEvent {
|
|
||||||
return tun.events
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *DummyTUN) Read(d []byte, offset int) (int, error) {
|
|
||||||
t, ok := <-tun.packets
|
|
||||||
if !ok {
|
|
||||||
return 0, errors.New("device closed")
|
|
||||||
}
|
|
||||||
copy(d[offset:], t)
|
|
||||||
return len(t), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func CreateDummyTUN(name string) (tun.TUNDevice, error) {
|
|
||||||
var dummy DummyTUN
|
|
||||||
dummy.mtu = 0
|
|
||||||
dummy.packets = make(chan []byte, 100)
|
|
||||||
dummy.events = make(chan tun.TUNEvent, 10)
|
|
||||||
return &dummy, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertNil(t *testing.T, err error) {
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func assertEqual(t *testing.T, a []byte, b []byte) {
|
|
||||||
if bytes.Compare(a, b) != 0 {
|
|
||||||
t.Fatal(a, "!=", b)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func randDevice(t *testing.T) *Device {
|
|
||||||
sk, err := newPrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
tun, _ := CreateDummyTUN("dummy")
|
|
||||||
logger := NewLogger(LogLevelError, "")
|
|
||||||
device := NewDevice(tun, logger)
|
|
||||||
device.SetPrivateKey(sk)
|
|
||||||
return device
|
|
||||||
}
|
|
||||||
12
tun/tun.go
12
tun/tun.go
@@ -9,21 +9,21 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
type TUNEvent int
|
type Event int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TUNEventUp = 1 << iota
|
EventUp = 1 << iota
|
||||||
TUNEventDown
|
EventDown
|
||||||
TUNEventMTUUpdate
|
EventMTUUpdate
|
||||||
)
|
)
|
||||||
|
|
||||||
type TUNDevice interface {
|
type Device interface {
|
||||||
File() *os.File // returns the file descriptor of the device
|
File() *os.File // returns the file descriptor of the device
|
||||||
Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
|
Read([]byte, int) (int, error) // read a packet from the device (without any additional headers)
|
||||||
Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
|
Write([]byte, int) (int, error) // writes a packet to the device (without any additional headers)
|
||||||
Flush() error // flush all previous writes to the device
|
Flush() error // flush all previous writes to the device
|
||||||
MTU() (int, error) // returns the MTU of the device
|
MTU() (int, error) // returns the MTU of the device
|
||||||
Name() (string, error) // fetches and returns the current name
|
Name() (string, error) // fetches and returns the current name
|
||||||
Events() chan TUNEvent // returns a constant channel of events related to the device
|
Events() chan Event // returns a constant channel of events related to the device
|
||||||
Close() error // stops the device and closes the event channel
|
Close() error // stops the device and closes the event channel
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
@@ -35,23 +36,36 @@ type sockaddrCtl struct {
|
|||||||
type NativeTun struct {
|
type NativeTun struct {
|
||||||
name string
|
name string
|
||||||
tunFile *os.File
|
tunFile *os.File
|
||||||
events chan TUNEvent
|
events chan Event
|
||||||
errors chan error
|
errors chan error
|
||||||
routeSocket int
|
routeSocket int
|
||||||
}
|
}
|
||||||
|
|
||||||
var sockaddrCtlSize uintptr = 32
|
var sockaddrCtlSize uintptr = 32
|
||||||
|
|
||||||
|
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
iface, err = net.InterfaceByIndex(index)
|
||||||
|
if err != nil {
|
||||||
|
if opErr, ok := err.(*net.OpError); ok {
|
||||||
|
if syscallErr, ok := opErr.Err.(*os.SyscallError); ok && syscallErr.Err == syscall.ENOMEM {
|
||||||
|
time.Sleep(time.Duration(i) * time.Second / 3)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return iface, err
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
||||||
var (
|
var (
|
||||||
statusUp bool
|
statusUp bool
|
||||||
statusMTU int
|
statusMTU int
|
||||||
)
|
)
|
||||||
|
|
||||||
defer func() {
|
defer close(tun.events)
|
||||||
close(tun.events)
|
|
||||||
tun.routeSocket = -1
|
|
||||||
}()
|
|
||||||
|
|
||||||
data := make([]byte, os.Getpagesize())
|
data := make([]byte, os.Getpagesize())
|
||||||
for {
|
for {
|
||||||
@@ -77,7 +91,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
iface, err := net.InterfaceByIndex(ifindex)
|
iface, err := retryInterfaceByIndex(ifindex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tun.errors <- err
|
tun.errors <- err
|
||||||
return
|
return
|
||||||
@@ -86,22 +100,22 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
|||||||
// Up / Down event
|
// Up / Down event
|
||||||
up := (iface.Flags & net.FlagUp) != 0
|
up := (iface.Flags & net.FlagUp) != 0
|
||||||
if up != statusUp && up {
|
if up != statusUp && up {
|
||||||
tun.events <- TUNEventUp
|
tun.events <- EventUp
|
||||||
}
|
}
|
||||||
if up != statusUp && !up {
|
if up != statusUp && !up {
|
||||||
tun.events <- TUNEventDown
|
tun.events <- EventDown
|
||||||
}
|
}
|
||||||
statusUp = up
|
statusUp = up
|
||||||
|
|
||||||
// MTU changes
|
// MTU changes
|
||||||
if iface.MTU != statusMTU {
|
if iface.MTU != statusMTU {
|
||||||
tun.events <- TUNEventMTUUpdate
|
tun.events <- EventMTUUpdate
|
||||||
}
|
}
|
||||||
statusMTU = iface.MTU
|
statusMTU = iface.MTU
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTUN(name string, mtu int) (TUNDevice, error) {
|
func CreateTUN(name string, mtu int) (Device, error) {
|
||||||
ifIndex := -1
|
ifIndex := -1
|
||||||
if name != "utun" {
|
if name != "utun" {
|
||||||
_, err := fmt.Sscanf(name, "utun%d", &ifIndex)
|
_, err := fmt.Sscanf(name, "utun%d", &ifIndex)
|
||||||
@@ -171,10 +185,10 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
|
|||||||
return tun, err
|
return tun, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
|
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||||
tun := &NativeTun{
|
tun := &NativeTun{
|
||||||
tunFile: file,
|
tunFile: file,
|
||||||
events: make(chan TUNEvent, 10),
|
events: make(chan Event, 10),
|
||||||
errors: make(chan error, 5),
|
errors: make(chan error, 5),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -244,7 +258,7 @@ func (tun *NativeTun) File() *os.File {
|
|||||||
return tun.tunFile
|
return tun.tunFile
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Events() chan TUNEvent {
|
func (tun *NativeTun) Events() chan Event {
|
||||||
return tun.events
|
return tun.events
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -286,7 +300,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Flush() error {
|
func (tun *NativeTun) Flush() error {
|
||||||
//TODO: can flushing be implemented by buffering and using sendmmsg?
|
// TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ const (
|
|||||||
_TUNSIFPID = 0x2000745f
|
_TUNSIFPID = 0x2000745f
|
||||||
)
|
)
|
||||||
|
|
||||||
//TODO: move into x/sys/unix
|
// TODO: move into x/sys/unix
|
||||||
const (
|
const (
|
||||||
SIOCGIFINFO_IN6 = 0xc048696c
|
SIOCGIFINFO_IN6 = 0xc048696c
|
||||||
SIOCSIFINFO_IN6 = 0xc048696d
|
SIOCSIFINFO_IN6 = 0xc048696d
|
||||||
@@ -79,7 +79,7 @@ type in6_ndireq struct {
|
|||||||
type NativeTun struct {
|
type NativeTun struct {
|
||||||
name string
|
name string
|
||||||
tunFile *os.File
|
tunFile *os.File
|
||||||
events chan TUNEvent
|
events chan Event
|
||||||
errors chan error
|
errors chan error
|
||||||
routeSocket int
|
routeSocket int
|
||||||
}
|
}
|
||||||
@@ -125,16 +125,16 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
|||||||
// Up / Down event
|
// Up / Down event
|
||||||
up := (iface.Flags & net.FlagUp) != 0
|
up := (iface.Flags & net.FlagUp) != 0
|
||||||
if up != statusUp && up {
|
if up != statusUp && up {
|
||||||
tun.events <- TUNEventUp
|
tun.events <- EventUp
|
||||||
}
|
}
|
||||||
if up != statusUp && !up {
|
if up != statusUp && !up {
|
||||||
tun.events <- TUNEventDown
|
tun.events <- EventDown
|
||||||
}
|
}
|
||||||
statusUp = up
|
statusUp = up
|
||||||
|
|
||||||
// MTU changes
|
// MTU changes
|
||||||
if iface.MTU != statusMTU {
|
if iface.MTU != statusMTU {
|
||||||
tun.events <- TUNEventMTUUpdate
|
tun.events <- EventMTUUpdate
|
||||||
}
|
}
|
||||||
statusMTU = iface.MTU
|
statusMTU = iface.MTU
|
||||||
}
|
}
|
||||||
@@ -246,7 +246,7 @@ func tunDestroy(name string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTUN(name string, mtu int) (TUNDevice, error) {
|
func CreateTUN(name string, mtu int) (Device, error) {
|
||||||
if len(name) > unix.IFNAMSIZ-1 {
|
if len(name) > unix.IFNAMSIZ-1 {
|
||||||
return nil, errors.New("interface name too long")
|
return nil, errors.New("interface name too long")
|
||||||
}
|
}
|
||||||
@@ -365,11 +365,11 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
|
|||||||
return CreateTUNFromFile(tunFile, mtu)
|
return CreateTUNFromFile(tunFile, mtu)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
|
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||||
|
|
||||||
tun := &NativeTun{
|
tun := &NativeTun{
|
||||||
tunFile: file,
|
tunFile: file,
|
||||||
events: make(chan TUNEvent, 10),
|
events: make(chan Event, 10),
|
||||||
errors: make(chan error, 1),
|
errors: make(chan error, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -425,7 +425,7 @@ func (tun *NativeTun) File() *os.File {
|
|||||||
return tun.tunFile
|
return tun.tunFile
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Events() chan TUNEvent {
|
func (tun *NativeTun) Events() chan Event {
|
||||||
return tun.events
|
return tun.events
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -467,7 +467,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Flush() error {
|
func (tun *NativeTun) Flush() error {
|
||||||
//TODO: can flushing be implemented by buffering and using sendmmsg?
|
// TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -31,11 +31,11 @@ const (
|
|||||||
|
|
||||||
type NativeTun struct {
|
type NativeTun struct {
|
||||||
tunFile *os.File
|
tunFile *os.File
|
||||||
index int32 // if index
|
index int32 // if index
|
||||||
name string // name of interface
|
name string // name of interface
|
||||||
errors chan error // async error handling
|
errors chan error // async error handling
|
||||||
events chan TUNEvent // device related events
|
events chan Event // device related events
|
||||||
nopi bool // the device was pased IFF_NO_PI
|
nopi bool // the device was passed IFF_NO_PI
|
||||||
netlinkSock int
|
netlinkSock int
|
||||||
netlinkCancel *rwcancel.RWCancel
|
netlinkCancel *rwcancel.RWCancel
|
||||||
hackListenerClosed sync.Mutex
|
hackListenerClosed sync.Mutex
|
||||||
@@ -64,9 +64,9 @@ func (tun *NativeTun) routineHackListener() {
|
|||||||
}
|
}
|
||||||
switch err {
|
switch err {
|
||||||
case unix.EINVAL:
|
case unix.EINVAL:
|
||||||
tun.events <- TUNEventUp
|
tun.events <- EventUp
|
||||||
case unix.EIO:
|
case unix.EIO:
|
||||||
tun.events <- TUNEventDown
|
tun.events <- EventDown
|
||||||
default:
|
default:
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -148,14 +148,14 @@ func (tun *NativeTun) routineNetlinkListener() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if info.Flags&unix.IFF_RUNNING != 0 {
|
if info.Flags&unix.IFF_RUNNING != 0 {
|
||||||
tun.events <- TUNEventUp
|
tun.events <- EventUp
|
||||||
}
|
}
|
||||||
|
|
||||||
if info.Flags&unix.IFF_RUNNING == 0 {
|
if info.Flags&unix.IFF_RUNNING == 0 {
|
||||||
tun.events <- TUNEventDown
|
tun.events <- EventDown
|
||||||
}
|
}
|
||||||
|
|
||||||
tun.events <- TUNEventMTUUpdate
|
tun.events <- EventMTUUpdate
|
||||||
|
|
||||||
default:
|
default:
|
||||||
remain = remain[hdr.Len:]
|
remain = remain[hdr.Len:]
|
||||||
@@ -320,7 +320,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Flush() error {
|
func (tun *NativeTun) Flush() error {
|
||||||
//TODO: can flushing be implemented by buffering and using sendmmsg?
|
// TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -342,7 +342,7 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Events() chan TUNEvent {
|
func (tun *NativeTun) Events() chan Event {
|
||||||
return tun.events
|
return tun.events
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -364,7 +364,7 @@ func (tun *NativeTun) Close() error {
|
|||||||
return err2
|
return err2
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTUN(name string, mtu int) (TUNDevice, error) {
|
func CreateTUN(name string, mtu int) (Device, error) {
|
||||||
nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
|
nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -400,10 +400,10 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
|
|||||||
return CreateTUNFromFile(fd, mtu)
|
return CreateTUNFromFile(fd, mtu)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
|
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||||
tun := &NativeTun{
|
tun := &NativeTun{
|
||||||
tunFile: file,
|
tunFile: file,
|
||||||
events: make(chan TUNEvent, 5),
|
events: make(chan Event, 5),
|
||||||
errors: make(chan error, 5),
|
errors: make(chan error, 5),
|
||||||
statusListenersShutdown: make(chan struct{}),
|
statusListenersShutdown: make(chan struct{}),
|
||||||
nopi: false,
|
nopi: false,
|
||||||
@@ -445,7 +445,7 @@ func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
|
|||||||
return tun, nil
|
return tun, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateUnmonitoredTUNFromFD(fd int) (TUNDevice, string, error) {
|
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
|
||||||
err := unix.SetNonblock(fd, true)
|
err := unix.SetNonblock(fd, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
@@ -453,7 +453,7 @@ func CreateUnmonitoredTUNFromFD(fd int) (TUNDevice, string, error) {
|
|||||||
file := os.NewFile(uintptr(fd), "/dev/tun")
|
file := os.NewFile(uintptr(fd), "/dev/tun")
|
||||||
tun := &NativeTun{
|
tun := &NativeTun{
|
||||||
tunFile: file,
|
tunFile: file,
|
||||||
events: make(chan TUNEvent, 5),
|
events: make(chan Event, 5),
|
||||||
errors: make(chan error, 5),
|
errors: make(chan error, 5),
|
||||||
nopi: true,
|
nopi: true,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ const _TUNSIFMODE = 0x8004745d
|
|||||||
type NativeTun struct {
|
type NativeTun struct {
|
||||||
name string
|
name string
|
||||||
tunFile *os.File
|
tunFile *os.File
|
||||||
events chan TUNEvent
|
events chan Event
|
||||||
errors chan error
|
errors chan error
|
||||||
routeSocket int
|
routeSocket int
|
||||||
}
|
}
|
||||||
@@ -42,13 +42,41 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
|||||||
|
|
||||||
defer close(tun.events)
|
defer close(tun.events)
|
||||||
|
|
||||||
|
check := func() bool {
|
||||||
|
iface, err := net.InterfaceByIndex(tunIfindex)
|
||||||
|
if err != nil {
|
||||||
|
tun.errors <- err
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Up / Down event
|
||||||
|
up := (iface.Flags & net.FlagUp) != 0
|
||||||
|
if up != statusUp && up {
|
||||||
|
tun.events <- EventUp
|
||||||
|
}
|
||||||
|
if up != statusUp && !up {
|
||||||
|
tun.events <- EventDown
|
||||||
|
}
|
||||||
|
statusUp = up
|
||||||
|
|
||||||
|
// MTU changes
|
||||||
|
if iface.MTU != statusMTU {
|
||||||
|
tun.events <- EventMTUUpdate
|
||||||
|
}
|
||||||
|
statusMTU = iface.MTU
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if check() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
data := make([]byte, os.Getpagesize())
|
data := make([]byte, os.Getpagesize())
|
||||||
for {
|
for {
|
||||||
retry:
|
|
||||||
n, err := unix.Read(tun.routeSocket, data)
|
n, err := unix.Read(tun.routeSocket, data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
|
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR {
|
||||||
goto retry
|
continue
|
||||||
}
|
}
|
||||||
tun.errors <- err
|
tun.errors <- err
|
||||||
return
|
return
|
||||||
@@ -65,28 +93,9 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
|||||||
if ifindex != tunIfindex {
|
if ifindex != tunIfindex {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if check() {
|
||||||
iface, err := net.InterfaceByIndex(ifindex)
|
|
||||||
if err != nil {
|
|
||||||
tun.errors <- err
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Up / Down event
|
|
||||||
up := (iface.Flags & net.FlagUp) != 0
|
|
||||||
if up != statusUp && up {
|
|
||||||
tun.events <- TUNEventUp
|
|
||||||
}
|
|
||||||
if up != statusUp && !up {
|
|
||||||
tun.events <- TUNEventDown
|
|
||||||
}
|
|
||||||
statusUp = up
|
|
||||||
|
|
||||||
// MTU changes
|
|
||||||
if iface.MTU != statusMTU {
|
|
||||||
tun.events <- TUNEventMTUUpdate
|
|
||||||
}
|
|
||||||
statusMTU = iface.MTU
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,7 +109,7 @@ func errorIsEBUSY(err error) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTUN(name string, mtu int) (TUNDevice, error) {
|
func CreateTUN(name string, mtu int) (Device, error) {
|
||||||
ifIndex := -1
|
ifIndex := -1
|
||||||
if name != "tun" {
|
if name != "tun" {
|
||||||
_, err := fmt.Sscanf(name, "tun%d", &ifIndex)
|
_, err := fmt.Sscanf(name, "tun%d", &ifIndex)
|
||||||
@@ -139,11 +148,10 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
|
|||||||
return tun, err
|
return tun, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
|
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||||
|
|
||||||
tun := &NativeTun{
|
tun := &NativeTun{
|
||||||
tunFile: file,
|
tunFile: file,
|
||||||
events: make(chan TUNEvent, 10),
|
events: make(chan Event, 10),
|
||||||
errors: make(chan error, 1),
|
errors: make(chan error, 1),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -173,10 +181,13 @@ func CreateTUNFromFile(file *os.File, mtu int) (TUNDevice, error) {
|
|||||||
|
|
||||||
go tun.routineRouteListener(tunIfindex)
|
go tun.routineRouteListener(tunIfindex)
|
||||||
|
|
||||||
err = tun.setMTU(mtu)
|
currentMTU, err := tun.MTU()
|
||||||
if err != nil {
|
if err != nil || currentMTU != mtu {
|
||||||
tun.Close()
|
err = tun.setMTU(mtu)
|
||||||
return nil, err
|
if err != nil {
|
||||||
|
tun.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return tun, nil
|
return tun, nil
|
||||||
@@ -197,7 +208,7 @@ func (tun *NativeTun) File() *os.File {
|
|||||||
return tun.tunFile
|
return tun.tunFile
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Events() chan TUNEvent {
|
func (tun *NativeTun) Events() chan Event {
|
||||||
return tun.events
|
return tun.events
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -239,7 +250,7 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Flush() error {
|
func (tun *NativeTun) Flush() error {
|
||||||
//TODO: can flushing be implemented by buffering and using sendmmsg?
|
// TODO: can flushing be implemented by buffering and using sendmmsg?
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -10,373 +10,266 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun/wintun"
|
"golang.zx2c4.com/wireguard/tun/wintun"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
packetExchangeMax uint32 = 256 // Number of packets that may be written at a time
|
rateMeasurementGranularity = uint64((time.Second / 2) / time.Nanosecond)
|
||||||
packetExchangeAlignment uint32 = 16 // Number of bytes packets are aligned to in exchange buffers
|
spinloopRateThreshold = 800000000 / 8 // 800mbps
|
||||||
packetSizeMax uint32 = 0xf000 - packetExchangeAlignment // Maximum packet size
|
spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
|
||||||
packetExchangeSize uint32 = 0x100000 // Exchange buffer size (defaults to 1MiB)
|
|
||||||
retryRate = 4 // Number of retries per second to reopen device pipe
|
|
||||||
retryTimeout = 30 // Number of seconds to tolerate adapter unavailable
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type exchgBufRead struct {
|
type rateJuggler struct {
|
||||||
data [packetExchangeSize]byte
|
current uint64
|
||||||
offset uint32
|
nextByteCount uint64
|
||||||
avail uint32
|
nextStartTime int64
|
||||||
}
|
changing int32
|
||||||
|
|
||||||
type exchgBufWrite struct {
|
|
||||||
data [packetExchangeSize]byte
|
|
||||||
offset uint32
|
|
||||||
packetNum uint32
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type NativeTun struct {
|
type NativeTun struct {
|
||||||
wt *wintun.Wintun
|
wt *wintun.Interface
|
||||||
tunFileRead *os.File
|
handle windows.Handle
|
||||||
tunFileWrite *os.File
|
close bool
|
||||||
tunLock sync.Mutex
|
events chan Event
|
||||||
close bool
|
errors chan error
|
||||||
rdBuff *exchgBufRead
|
forcedMTU int
|
||||||
wrBuff *exchgBufWrite
|
rate rateJuggler
|
||||||
events chan TUNEvent
|
rings *wintun.RingDescriptor
|
||||||
errors chan error
|
writeLock sync.Mutex
|
||||||
forcedMTU int
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func packetAlign(size uint32) uint32 {
|
const WintunPool = wintun.Pool("WireGuard")
|
||||||
return (size + (packetExchangeAlignment - 1)) &^ (packetExchangeAlignment - 1)
|
|
||||||
|
//go:linkname procyield runtime.procyield
|
||||||
|
func procyield(cycles uint32)
|
||||||
|
|
||||||
|
//go:linkname nanotime runtime.nanotime
|
||||||
|
func nanotime() int64
|
||||||
|
|
||||||
|
//
|
||||||
|
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
|
||||||
|
// interface with the same name exist, it is reused.
|
||||||
|
//
|
||||||
|
func CreateTUN(ifname string, mtu int) (Device, error) {
|
||||||
|
return CreateTUNWithRequestedGUID(ifname, nil, mtu)
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// CreateTUN creates a Wintun adapter with the given name. Should a Wintun
|
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
|
||||||
// adapter with the same name exist, it is reused.
|
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
|
||||||
//
|
//
|
||||||
func CreateTUN(ifname string) (TUNDevice, error) {
|
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
|
||||||
var err error
|
var err error
|
||||||
var wt *wintun.Wintun
|
var wt *wintun.Interface
|
||||||
|
|
||||||
// Does an interface with this name already exist?
|
// Does an interface with this name already exist?
|
||||||
wt, err = wintun.GetInterface(ifname, 0)
|
wt, err = WintunPool.GetInterface(ifname)
|
||||||
if wt == nil {
|
if err == nil {
|
||||||
// Interface does not exist or an error occurred. Create one.
|
// If so, we delete it, in case it has weird residual configuration.
|
||||||
wt, _, err = wintun.CreateInterface("WireGuard Tunnel Adapter", 0)
|
_, err = wt.DeleteInterface()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("wintun.CreateInterface: %v", err)
|
return nil, fmt.Errorf("Error deleting already existing interface: %v", err)
|
||||||
}
|
}
|
||||||
} else if err != nil {
|
|
||||||
// Foreign interface with the same name found.
|
|
||||||
// We could create a Wintun interface under a temporary name. But, should our
|
|
||||||
// process die without deleting this interface first, the interface would remain
|
|
||||||
// orphaned.
|
|
||||||
return nil, fmt.Errorf("wintun.GetInterface: %v", err)
|
|
||||||
}
|
}
|
||||||
|
wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID)
|
||||||
err = wt.SetInterfaceName(ifname)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
wt.DeleteInterface(0)
|
return nil, fmt.Errorf("Error creating interface: %v", err)
|
||||||
return nil, fmt.Errorf("wintun.SetInterfaceName: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = wt.FlushInterface()
|
forcedMTU := 1420
|
||||||
if err != nil {
|
if mtu > 0 {
|
||||||
wt.DeleteInterface(0)
|
forcedMTU = mtu
|
||||||
return nil, fmt.Errorf("wintun.FlushInterface: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return &NativeTun{
|
tun := &NativeTun{
|
||||||
wt: wt,
|
wt: wt,
|
||||||
rdBuff: &exchgBufRead{},
|
handle: windows.InvalidHandle,
|
||||||
wrBuff: &exchgBufWrite{},
|
events: make(chan Event, 10),
|
||||||
events: make(chan TUNEvent, 10),
|
|
||||||
errors: make(chan error, 1),
|
errors: make(chan error, 1),
|
||||||
forcedMTU: 1500,
|
forcedMTU: forcedMTU,
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) openTUN() error {
|
|
||||||
retries := retryTimeout * retryRate
|
|
||||||
if tun.close {
|
|
||||||
return os.ErrClosed
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
tun.rings, err = wintun.NewRingDescriptor()
|
||||||
name := tun.wt.DataFileName()
|
if err != nil {
|
||||||
for tun.tunFileRead == nil {
|
tun.Close()
|
||||||
tun.tunFileRead, err = os.OpenFile(name, os.O_RDONLY, 0)
|
return nil, fmt.Errorf("Error creating events: %v", err)
|
||||||
if err != nil {
|
|
||||||
if retries > 0 && !tun.close {
|
|
||||||
time.Sleep(time.Second / retryRate)
|
|
||||||
retries--
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for tun.tunFileWrite == nil {
|
|
||||||
tun.tunFileWrite, err = os.OpenFile(name, os.O_WRONLY, 0)
|
|
||||||
if err != nil {
|
|
||||||
if retries > 0 && !tun.close {
|
|
||||||
time.Sleep(time.Second / retryRate)
|
|
||||||
retries--
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) closeTUN() (err error) {
|
tun.handle, err = tun.wt.Register(tun.rings)
|
||||||
for tun.tunFileRead != nil {
|
if err != nil {
|
||||||
tun.tunLock.Lock()
|
tun.Close()
|
||||||
if tun.tunFileRead == nil {
|
return nil, fmt.Errorf("Error registering rings: %v", err)
|
||||||
tun.tunLock.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
t := tun.tunFileRead
|
|
||||||
tun.tunFileRead = nil
|
|
||||||
windows.CancelIoEx(windows.Handle(t.Fd()), nil)
|
|
||||||
err = t.Close()
|
|
||||||
tun.tunLock.Unlock()
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
for tun.tunFileWrite != nil {
|
return tun, nil
|
||||||
tun.tunLock.Lock()
|
|
||||||
if tun.tunFileWrite == nil {
|
|
||||||
tun.tunLock.Unlock()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
t := tun.tunFileWrite
|
|
||||||
tun.tunFileWrite = nil
|
|
||||||
windows.CancelIoEx(windows.Handle(t.Fd()), nil)
|
|
||||||
err2 := t.Close()
|
|
||||||
tun.tunLock.Unlock()
|
|
||||||
if err == nil {
|
|
||||||
err = err2
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) getTUN() (read *os.File, write *os.File, err error) {
|
|
||||||
read, write = tun.tunFileRead, tun.tunFileWrite
|
|
||||||
if read == nil || write == nil {
|
|
||||||
read, write = nil, nil
|
|
||||||
tun.tunLock.Lock()
|
|
||||||
if tun.tunFileRead != nil && tun.tunFileWrite != nil {
|
|
||||||
read, write = tun.tunFileRead, tun.tunFileWrite
|
|
||||||
tun.tunLock.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = tun.closeTUN()
|
|
||||||
if err != nil {
|
|
||||||
tun.tunLock.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
err = tun.openTUN()
|
|
||||||
if err == nil {
|
|
||||||
read, write = tun.tunFileRead, tun.tunFileWrite
|
|
||||||
}
|
|
||||||
tun.tunLock.Unlock()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Name() (string, error) {
|
func (tun *NativeTun) Name() (string, error) {
|
||||||
return tun.wt.GetInterfaceName()
|
return tun.wt.Name()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) File() *os.File {
|
func (tun *NativeTun) File() *os.File {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Events() chan TUNEvent {
|
func (tun *NativeTun) Events() chan Event {
|
||||||
return tun.events
|
return tun.events
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Close() error {
|
func (tun *NativeTun) Close() error {
|
||||||
tun.close = true
|
tun.close = true
|
||||||
err1 := tun.closeTUN()
|
if tun.rings.Send.TailMoved != 0 {
|
||||||
|
windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping
|
||||||
if tun.events != nil {
|
|
||||||
close(tun.events)
|
|
||||||
}
|
}
|
||||||
|
if tun.handle != windows.InvalidHandle {
|
||||||
_, _, err2 := tun.wt.DeleteInterface(0)
|
windows.CloseHandle(tun.handle)
|
||||||
if err1 == nil {
|
|
||||||
err1 = err2
|
|
||||||
}
|
}
|
||||||
|
tun.rings.Close()
|
||||||
return err1
|
var err error
|
||||||
|
if tun.wt != nil {
|
||||||
|
_, err = tun.wt.DeleteInterface()
|
||||||
|
}
|
||||||
|
close(tun.events)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) MTU() (int, error) {
|
func (tun *NativeTun) MTU() (int, error) {
|
||||||
return tun.forcedMTU, nil
|
return tun.forcedMTU, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
|
// TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
|
||||||
func (tun *NativeTun) ForceMTU(mtu int) {
|
func (tun *NativeTun) ForceMTU(mtu int) {
|
||||||
tun.forcedMTU = mtu
|
tun.forcedMTU = mtu
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
|
||||||
|
|
||||||
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
||||||
|
retry:
|
||||||
select {
|
select {
|
||||||
case err := <-tun.errors:
|
case err := <-tun.errors:
|
||||||
return 0, err
|
return 0, err
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
|
if tun.close {
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head)
|
||||||
|
if buffHead >= wintun.PacketCapacity {
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
start := nanotime()
|
||||||
|
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
|
||||||
|
var buffTail uint32
|
||||||
for {
|
for {
|
||||||
if tun.rdBuff.offset+packetExchangeAlignment <= tun.rdBuff.avail {
|
buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail)
|
||||||
// Get packet from the exchange buffer.
|
if buffHead != buffTail {
|
||||||
packet := tun.rdBuff.data[tun.rdBuff.offset:]
|
|
||||||
size := *(*uint32)(unsafe.Pointer(&packet[0]))
|
|
||||||
pSize := packetAlign(packetExchangeAlignment + size)
|
|
||||||
if packetSizeMax < size || tun.rdBuff.avail < tun.rdBuff.offset+pSize {
|
|
||||||
// Invalid packet size.
|
|
||||||
tun.rdBuff.avail = 0
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
packet = packet[packetExchangeAlignment : packetExchangeAlignment+size]
|
|
||||||
|
|
||||||
// Copy data.
|
|
||||||
copy(buff[offset:], packet)
|
|
||||||
tun.rdBuff.offset += pSize
|
|
||||||
return int(size), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get TUN data pipe.
|
|
||||||
file, _, err := tun.getTUN()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fill queue.
|
|
||||||
retries := 1000
|
|
||||||
for {
|
|
||||||
n, err := file.Read(tun.rdBuff.data[:])
|
|
||||||
if err != nil {
|
|
||||||
tun.rdBuff.offset = 0
|
|
||||||
tun.rdBuff.avail = 0
|
|
||||||
pe, ok := err.(*os.PathError)
|
|
||||||
if tun.close {
|
|
||||||
return 0, os.ErrClosed
|
|
||||||
}
|
|
||||||
if retries > 0 && ok && pe.Err == windows.ERROR_OPERATION_ABORTED {
|
|
||||||
retries--
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ok && pe.Err == windows.ERROR_HANDLE_EOF {
|
|
||||||
tun.closeTUN()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
tun.rdBuff.offset = 0
|
|
||||||
tun.rdBuff.avail = uint32(n)
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
if tun.close {
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
}
|
||||||
|
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
|
||||||
|
windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE)
|
||||||
|
goto retry
|
||||||
|
}
|
||||||
|
procyield(1)
|
||||||
|
}
|
||||||
|
if buffTail >= wintun.PacketCapacity {
|
||||||
|
return 0, os.ErrClosed
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
// Note: flush() and putTunPacket() assume the caller comes only from a single thread; there's no locking.
|
buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead)
|
||||||
|
if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) {
|
||||||
|
return 0, errors.New("incomplete packet header in send ring")
|
||||||
|
}
|
||||||
|
|
||||||
|
packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead]))
|
||||||
|
if packet.Size > wintun.PacketSizeMax {
|
||||||
|
return 0, errors.New("packet too big in send ring")
|
||||||
|
}
|
||||||
|
|
||||||
|
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size)
|
||||||
|
if alignedPacketSize > buffContent {
|
||||||
|
return 0, errors.New("incomplete packet in send ring")
|
||||||
|
}
|
||||||
|
|
||||||
|
copy(buff[offset:], packet.Data[:packet.Size])
|
||||||
|
buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize)
|
||||||
|
atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead)
|
||||||
|
tun.rate.update(uint64(packet.Size))
|
||||||
|
return int(packet.Size), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Flush() error {
|
func (tun *NativeTun) Flush() error {
|
||||||
if tun.wrBuff.offset == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
// Get TUN data pipe.
|
|
||||||
_, file, err := tun.getTUN()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Flush write buffer.
|
|
||||||
retries := 1000
|
|
||||||
for {
|
|
||||||
_, err = file.Write(tun.wrBuff.data[:tun.wrBuff.offset])
|
|
||||||
tun.wrBuff.packetNum = 0
|
|
||||||
tun.wrBuff.offset = 0
|
|
||||||
if err != nil {
|
|
||||||
pe, ok := err.(*os.PathError)
|
|
||||||
if tun.close {
|
|
||||||
return os.ErrClosed
|
|
||||||
}
|
|
||||||
if retries > 0 && ok && pe.Err == windows.ERROR_OPERATION_ABORTED {
|
|
||||||
retries--
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if ok && pe.Err == windows.ERROR_HANDLE_EOF {
|
|
||||||
tun.closeTUN()
|
|
||||||
break
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (tun *NativeTun) putTunPacket(buff []byte) error {
|
|
||||||
size := uint32(len(buff))
|
|
||||||
if size == 0 {
|
|
||||||
return errors.New("Empty packet")
|
|
||||||
}
|
|
||||||
if size > packetSizeMax {
|
|
||||||
return errors.New("Packet too big")
|
|
||||||
}
|
|
||||||
pSize := packetAlign(packetExchangeAlignment + size)
|
|
||||||
|
|
||||||
if tun.wrBuff.packetNum >= packetExchangeMax || tun.wrBuff.offset+pSize >= packetExchangeSize {
|
|
||||||
// Exchange buffer is full -> flush first.
|
|
||||||
err := tun.Flush()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write packet to the exchange buffer.
|
|
||||||
packet := tun.wrBuff.data[tun.wrBuff.offset : tun.wrBuff.offset+pSize]
|
|
||||||
*(*uint32)(unsafe.Pointer(&packet[0])) = size
|
|
||||||
packet = packet[packetExchangeAlignment : packetExchangeAlignment+size]
|
|
||||||
copy(packet, buff)
|
|
||||||
|
|
||||||
tun.wrBuff.packetNum++
|
|
||||||
tun.wrBuff.offset += pSize
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||||
err := tun.putTunPacket(buff[offset:])
|
if tun.close {
|
||||||
if err != nil {
|
return 0, os.ErrClosed
|
||||||
return 0, err
|
|
||||||
}
|
}
|
||||||
return len(buff) - offset, nil
|
|
||||||
|
packetSize := uint32(len(buff) - offset)
|
||||||
|
tun.rate.update(uint64(packetSize))
|
||||||
|
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
|
||||||
|
|
||||||
|
tun.writeLock.Lock()
|
||||||
|
defer tun.writeLock.Unlock()
|
||||||
|
|
||||||
|
buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
|
||||||
|
if buffHead >= wintun.PacketCapacity {
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail)
|
||||||
|
if buffTail >= wintun.PacketCapacity {
|
||||||
|
return 0, os.ErrClosed
|
||||||
|
}
|
||||||
|
|
||||||
|
buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment)
|
||||||
|
if alignedPacketSize > buffSpace {
|
||||||
|
return 0, nil // Dropping when ring is full.
|
||||||
|
}
|
||||||
|
|
||||||
|
packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail]))
|
||||||
|
packet.Size = packetSize
|
||||||
|
copy(packet.Data[:packetSize], buff[offset:])
|
||||||
|
atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize))
|
||||||
|
if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 {
|
||||||
|
windows.SetEvent(tun.rings.Receive.TailMoved)
|
||||||
|
}
|
||||||
|
return int(packetSize), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
// LUID returns Windows interface instance ID.
|
||||||
// GUID returns Windows adapter instance ID.
|
|
||||||
//
|
|
||||||
func (tun *NativeTun) GUID() windows.GUID {
|
|
||||||
return tun.wt.CfgInstanceID
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// GUID returns Windows adapter instance ID.
|
|
||||||
//
|
|
||||||
func (tun *NativeTun) LUID() uint64 {
|
func (tun *NativeTun) LUID() uint64 {
|
||||||
return ((uint64(tun.wt.LUIDIndex) & ((1 << 24) - 1)) << 24) | ((uint64(tun.wt.IfType) & ((1 << 16) - 1)) << 48)
|
return tun.wt.LUID()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Version returns the version of the Wintun driver and NDIS system currently loaded.
|
||||||
|
func (tun *NativeTun) Version() (driverVersion string, ndisVersion string, err error) {
|
||||||
|
return tun.wt.Version()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (rate *rateJuggler) update(packetLen uint64) {
|
||||||
|
now := nanotime()
|
||||||
|
total := atomic.AddUint64(&rate.nextByteCount, packetLen)
|
||||||
|
period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
|
||||||
|
if period >= rateMeasurementGranularity {
|
||||||
|
if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
atomic.StoreInt64(&rate.nextStartTime, now)
|
||||||
|
atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
|
||||||
|
atomic.StoreUint64(&rate.nextByteCount, 0)
|
||||||
|
atomic.StoreInt32(&rate.changing, 0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,41 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package guid
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
//sys clsidFromString(lpsz *uint16, pclsid *windows.GUID) (err error) [failretval!=0] = ole32.CLSIDFromString
|
|
||||||
|
|
||||||
//
|
|
||||||
// FromString parses "{XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}" string to GUID.
|
|
||||||
//
|
|
||||||
func FromString(str string) (*windows.GUID, error) {
|
|
||||||
strUTF16, err := syscall.UTF16PtrFromString(str)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
guid := &windows.GUID{}
|
|
||||||
err = clsidFromString(strUTF16, guid)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return guid, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//
|
|
||||||
// ToString function converts GUID to string
|
|
||||||
// "{XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}".
|
|
||||||
//
|
|
||||||
// The resulting string is uppercase.
|
|
||||||
//
|
|
||||||
func ToString(guid *windows.GUID) string {
|
|
||||||
return fmt.Sprintf("{%08X-%04X-%04X-%04X-%012X}", guid.Data1, guid.Data2, guid.Data3, guid.Data4[:2], guid.Data4[2:])
|
|
||||||
}
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package guid
|
|
||||||
|
|
||||||
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zguid_windows.go guid_windows.go
|
|
||||||
@@ -1,55 +0,0 @@
|
|||||||
// Code generated by 'go generate'; DO NOT EDIT.
|
|
||||||
|
|
||||||
package guid
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
var _ unsafe.Pointer
|
|
||||||
|
|
||||||
// Do the interface allocations only once for common
|
|
||||||
// Errno values.
|
|
||||||
const (
|
|
||||||
errnoERROR_IO_PENDING = 997
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
|
||||||
)
|
|
||||||
|
|
||||||
// errnoErr returns common boxed Errno values, to prevent
|
|
||||||
// allocations at runtime.
|
|
||||||
func errnoErr(e syscall.Errno) error {
|
|
||||||
switch e {
|
|
||||||
case 0:
|
|
||||||
return nil
|
|
||||||
case errnoERROR_IO_PENDING:
|
|
||||||
return errERROR_IO_PENDING
|
|
||||||
}
|
|
||||||
// TODO: add more here, after collecting data on the common
|
|
||||||
// error values see on Windows. (perhaps when running
|
|
||||||
// all.bat?)
|
|
||||||
return e
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
modole32 = windows.NewLazySystemDLL("ole32.dll")
|
|
||||||
|
|
||||||
procCLSIDFromString = modole32.NewProc("CLSIDFromString")
|
|
||||||
)
|
|
||||||
|
|
||||||
func clsidFromString(lpsz *uint16, pclsid *windows.GUID) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procCLSIDFromString.Addr(), 2, uintptr(unsafe.Pointer(lpsz)), uintptr(unsafe.Pointer(pclsid)), 0)
|
|
||||||
if r1 != 0 {
|
|
||||||
if e1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
} else {
|
|
||||||
err = syscall.EINVAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
25
tun/wintun/iphlpapi/conversion_windows.go
Normal file
25
tun/wintun/iphlpapi/conversion_windows.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package iphlpapi
|
||||||
|
|
||||||
|
import "golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
//sys convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) = iphlpapi.ConvertInterfaceLuidToGuid
|
||||||
|
//sys convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) = iphlpapi.ConvertInterfaceAliasToLuid
|
||||||
|
|
||||||
|
func InterfaceGUIDFromAlias(alias string) (*windows.GUID, error) {
|
||||||
|
var luid uint64
|
||||||
|
var guid windows.GUID
|
||||||
|
err := convertInterfaceAliasToLUID(windows.StringToUTF16Ptr(alias), &luid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
err = convertInterfaceLUIDToGUID(&luid, &guid)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &guid, nil
|
||||||
|
}
|
||||||
8
tun/wintun/iphlpapi/mksyscall.go
Normal file
8
tun/wintun/iphlpapi/mksyscall.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package iphlpapi
|
||||||
|
|
||||||
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go conversion_windows.go
|
||||||
60
tun/wintun/iphlpapi/zsyscall_windows.go
Normal file
60
tun/wintun/iphlpapi/zsyscall_windows.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
// Code generated by 'go generate'; DO NOT EDIT.
|
||||||
|
|
||||||
|
package iphlpapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ unsafe.Pointer
|
||||||
|
|
||||||
|
// Do the interface allocations only once for common
|
||||||
|
// Errno values.
|
||||||
|
const (
|
||||||
|
errnoERROR_IO_PENDING = 997
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||||
|
)
|
||||||
|
|
||||||
|
// errnoErr returns common boxed Errno values, to prevent
|
||||||
|
// allocations at runtime.
|
||||||
|
func errnoErr(e syscall.Errno) error {
|
||||||
|
switch e {
|
||||||
|
case 0:
|
||||||
|
return nil
|
||||||
|
case errnoERROR_IO_PENDING:
|
||||||
|
return errERROR_IO_PENDING
|
||||||
|
}
|
||||||
|
// TODO: add more here, after collecting data on the common
|
||||||
|
// error values see on Windows. (perhaps when running
|
||||||
|
// all.bat?)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
|
||||||
|
|
||||||
|
procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid")
|
||||||
|
procConvertInterfaceAliasToLuid = modiphlpapi.NewProc("ConvertInterfaceAliasToLuid")
|
||||||
|
)
|
||||||
|
|
||||||
|
func convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) {
|
||||||
|
r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0)
|
||||||
|
if r0 != 0 {
|
||||||
|
ret = syscall.Errno(r0)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) {
|
||||||
|
r0, _, _ := syscall.Syscall(procConvertInterfaceAliasToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceAlias)), uintptr(unsafe.Pointer(interfaceLUID)), 0)
|
||||||
|
if r0 != 0 {
|
||||||
|
ret = syscall.Errno(r0)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
98
tun/wintun/namespace_windows.go
Normal file
98
tun/wintun/namespace_windows.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package wintun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/hex"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/crypto/blake2s"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
"golang.org/x/text/unicode/norm"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/tun/wintun/namespaceapi"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
wintunObjectSecurityAttributes *windows.SecurityAttributes
|
||||||
|
hasInitializedNamespace bool
|
||||||
|
initializingNamespace sync.Mutex
|
||||||
|
)
|
||||||
|
|
||||||
|
func initializeNamespace() error {
|
||||||
|
initializingNamespace.Lock()
|
||||||
|
defer initializingNamespace.Unlock()
|
||||||
|
if hasInitializedNamespace {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
sd, err := windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("SddlToSecurityDescriptor failed: %v", err)
|
||||||
|
}
|
||||||
|
wintunObjectSecurityAttributes = &windows.SecurityAttributes{
|
||||||
|
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
|
||||||
|
SecurityDescriptor: sd,
|
||||||
|
}
|
||||||
|
sid, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("CreateWellKnownSid(LOCAL_SYSTEM) failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
boundary, err := namespaceapi.CreateBoundaryDescriptor("Wintun")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("CreateBoundaryDescriptor failed: %v", err)
|
||||||
|
}
|
||||||
|
err = boundary.AddSid(sid)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("AddSIDToBoundaryDescriptor failed: %v", err)
|
||||||
|
}
|
||||||
|
for {
|
||||||
|
_, err = namespaceapi.CreatePrivateNamespace(wintunObjectSecurityAttributes, boundary, "Wintun")
|
||||||
|
if err == windows.ERROR_ALREADY_EXISTS {
|
||||||
|
_, err = namespaceapi.OpenPrivateNamespace(boundary, "Wintun")
|
||||||
|
if err == windows.ERROR_PATH_NOT_FOUND {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("Create/OpenPrivateNamespace failed: %v", err)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
hasInitializedNamespace = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pool Pool) takeNameMutex() (windows.Handle, error) {
|
||||||
|
err := initializeNamespace()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
const mutexLabel = "WireGuard Adapter Name Mutex Stable Suffix v1 jason@zx2c4.com"
|
||||||
|
b2, _ := blake2s.New256(nil)
|
||||||
|
b2.Write([]byte(mutexLabel))
|
||||||
|
b2.Write(norm.NFC.Bytes([]byte(string(pool))))
|
||||||
|
mutexName := `Wintun\Wintun-Name-Mutex-` + hex.EncodeToString(b2.Sum(nil))
|
||||||
|
mutex, err := windows.CreateMutex(wintunObjectSecurityAttributes, false, windows.StringToUTF16Ptr(mutexName))
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("Error creating name mutex: %v", err)
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
event, err := windows.WaitForSingleObject(mutex, windows.INFINITE)
|
||||||
|
if err != nil {
|
||||||
|
windows.CloseHandle(mutex)
|
||||||
|
return 0, fmt.Errorf("Error waiting on name mutex: %v", err)
|
||||||
|
}
|
||||||
|
if event != windows.WAIT_OBJECT_0 && event != windows.WAIT_ABANDONED {
|
||||||
|
windows.CloseHandle(mutex)
|
||||||
|
return 0, errors.New("Error with event trigger of name mutex")
|
||||||
|
}
|
||||||
|
return mutex, nil
|
||||||
|
}
|
||||||
8
tun/wintun/namespaceapi/mksyscall.go
Normal file
8
tun/wintun/namespaceapi/mksyscall.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package namespaceapi
|
||||||
|
|
||||||
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go namespaceapi_windows.go
|
||||||
83
tun/wintun/namespaceapi/namespaceapi_windows.go
Normal file
83
tun/wintun/namespaceapi/namespaceapi_windows.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package namespaceapi
|
||||||
|
|
||||||
|
import "golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
//sys createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) = kernel32.CreateBoundaryDescriptorW
|
||||||
|
//sys deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) = kernel32.DeleteBoundaryDescriptor
|
||||||
|
//sys addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) = kernel32.AddSIDToBoundaryDescriptor
|
||||||
|
//sys createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.CreatePrivateNamespaceW
|
||||||
|
//sys openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.OpenPrivateNamespaceW
|
||||||
|
//sys closePrivateNamespace(handle windows.Handle, flags uint32) (err error) = kernel32.ClosePrivateNamespace
|
||||||
|
|
||||||
|
// BoundaryDescriptor represents a boundary that defines how the objects in the namespace are to be isolated.
|
||||||
|
type BoundaryDescriptor windows.Handle
|
||||||
|
|
||||||
|
// CreateBoundaryDescriptor creates a boundary descriptor.
|
||||||
|
func CreateBoundaryDescriptor(name string) (BoundaryDescriptor, error) {
|
||||||
|
name16, err := windows.UTF16PtrFromString(name)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
handle, err := createBoundaryDescriptor(name16, 0)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return BoundaryDescriptor(handle), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete deletes the specified boundary descriptor.
|
||||||
|
func (bd BoundaryDescriptor) Delete() {
|
||||||
|
deleteBoundaryDescriptor(windows.Handle(bd))
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddSid adds a security identifier (SID) to the specified boundary descriptor.
|
||||||
|
func (bd *BoundaryDescriptor) AddSid(requiredSid *windows.SID) error {
|
||||||
|
return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid)
|
||||||
|
}
|
||||||
|
|
||||||
|
// PrivateNamespace represents a private namespace.
|
||||||
|
type PrivateNamespace windows.Handle
|
||||||
|
|
||||||
|
// CreatePrivateNamespace creates a private namespace.
|
||||||
|
func CreatePrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) {
|
||||||
|
aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
handle, err := createPrivateNamespace(privateNamespaceAttributes, windows.Handle(boundaryDescriptor), aliasPrefix16)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return PrivateNamespace(handle), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenPrivateNamespace opens a private namespace.
|
||||||
|
func OpenPrivateNamespace(boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) {
|
||||||
|
aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
handle, err := openPrivateNamespace(windows.Handle(boundaryDescriptor), aliasPrefix16)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return PrivateNamespace(handle), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ClosePrivateNamespaceFlags describes flags that are used by PrivateNamespace's Close() method.
|
||||||
|
type ClosePrivateNamespaceFlags uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PrivateNamespaceFlagDestroy makes the close to destroy the namespace.
|
||||||
|
PrivateNamespaceFlagDestroy = ClosePrivateNamespaceFlags(0x1)
|
||||||
|
)
|
||||||
|
|
||||||
|
// Close closes an open namespace handle.
|
||||||
|
func (pns PrivateNamespace) Close(flags ClosePrivateNamespaceFlags) error {
|
||||||
|
return closePrivateNamespace(windows.Handle(pns), uint32(flags))
|
||||||
|
}
|
||||||
116
tun/wintun/namespaceapi/zsyscall_windows.go
Normal file
116
tun/wintun/namespaceapi/zsyscall_windows.go
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
// Code generated by 'go generate'; DO NOT EDIT.
|
||||||
|
|
||||||
|
package namespaceapi
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ unsafe.Pointer
|
||||||
|
|
||||||
|
// Do the interface allocations only once for common
|
||||||
|
// Errno values.
|
||||||
|
const (
|
||||||
|
errnoERROR_IO_PENDING = 997
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||||
|
)
|
||||||
|
|
||||||
|
// errnoErr returns common boxed Errno values, to prevent
|
||||||
|
// allocations at runtime.
|
||||||
|
func errnoErr(e syscall.Errno) error {
|
||||||
|
switch e {
|
||||||
|
case 0:
|
||||||
|
return nil
|
||||||
|
case errnoERROR_IO_PENDING:
|
||||||
|
return errERROR_IO_PENDING
|
||||||
|
}
|
||||||
|
// TODO: add more here, after collecting data on the common
|
||||||
|
// error values see on Windows. (perhaps when running
|
||||||
|
// all.bat?)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||||
|
|
||||||
|
procCreateBoundaryDescriptorW = modkernel32.NewProc("CreateBoundaryDescriptorW")
|
||||||
|
procDeleteBoundaryDescriptor = modkernel32.NewProc("DeleteBoundaryDescriptor")
|
||||||
|
procAddSIDToBoundaryDescriptor = modkernel32.NewProc("AddSIDToBoundaryDescriptor")
|
||||||
|
procCreatePrivateNamespaceW = modkernel32.NewProc("CreatePrivateNamespaceW")
|
||||||
|
procOpenPrivateNamespaceW = modkernel32.NewProc("OpenPrivateNamespaceW")
|
||||||
|
procClosePrivateNamespace = modkernel32.NewProc("ClosePrivateNamespace")
|
||||||
|
)
|
||||||
|
|
||||||
|
func createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) {
|
||||||
|
r0, _, e1 := syscall.Syscall(procCreateBoundaryDescriptorW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(flags), 0)
|
||||||
|
handle = windows.Handle(r0)
|
||||||
|
if handle == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) {
|
||||||
|
syscall.Syscall(procDeleteBoundaryDescriptor.Addr(), 1, uintptr(boundaryDescriptor), 0, 0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) {
|
||||||
|
r1, _, e1 := syscall.Syscall(procAddSIDToBoundaryDescriptor.Addr(), 2, uintptr(unsafe.Pointer(boundaryDescriptor)), uintptr(unsafe.Pointer(requiredSid)), 0)
|
||||||
|
if r1 == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) {
|
||||||
|
r0, _, e1 := syscall.Syscall(procCreatePrivateNamespaceW.Addr(), 3, uintptr(unsafe.Pointer(privateNamespaceAttributes)), uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)))
|
||||||
|
handle = windows.Handle(r0)
|
||||||
|
if handle == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) {
|
||||||
|
r0, _, e1 := syscall.Syscall(procOpenPrivateNamespaceW.Addr(), 2, uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)), 0)
|
||||||
|
handle = windows.Handle(r0)
|
||||||
|
if handle == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func closePrivateNamespace(handle windows.Handle, flags uint32) (err error) {
|
||||||
|
r1, _, e1 := syscall.Syscall(procClosePrivateNamespace.Addr(), 2, uintptr(handle), uintptr(flags), 0)
|
||||||
|
if r1 == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
8
tun/wintun/nci/mksyscall.go
Normal file
8
tun/wintun/nci/mksyscall.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package nci
|
||||||
|
|
||||||
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go nci_windows.go
|
||||||
28
tun/wintun/nci/nci_windows.go
Normal file
28
tun/wintun/nci/nci_windows.go
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package nci
|
||||||
|
|
||||||
|
import "golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
//sys nciSetConnectionName(guid *windows.GUID, newName *uint16) (ret error) = nci.NciSetConnectionName
|
||||||
|
//sys nciGetConnectionName(guid *windows.GUID, destName *uint16, inDestNameBytes uint32, outDestNameBytes *uint32) (ret error) = nci.NciGetConnectionName
|
||||||
|
|
||||||
|
func SetConnectionName(guid *windows.GUID, newName string) error {
|
||||||
|
newName16, err := windows.UTF16PtrFromString(newName)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nciSetConnectionName(guid, newName16)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ConnectionName(guid *windows.GUID) (string, error) {
|
||||||
|
var name [0x400]uint16
|
||||||
|
err := nciGetConnectionName(guid, &name[0], uint32(len(name)*2), nil)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return windows.UTF16ToString(name[:]), nil
|
||||||
|
}
|
||||||
60
tun/wintun/nci/zsyscall_windows.go
Normal file
60
tun/wintun/nci/zsyscall_windows.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
// Code generated by 'go generate'; DO NOT EDIT.
|
||||||
|
|
||||||
|
package nci
|
||||||
|
|
||||||
|
import (
|
||||||
|
"syscall"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
var _ unsafe.Pointer
|
||||||
|
|
||||||
|
// Do the interface allocations only once for common
|
||||||
|
// Errno values.
|
||||||
|
const (
|
||||||
|
errnoERROR_IO_PENDING = 997
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||||
|
)
|
||||||
|
|
||||||
|
// errnoErr returns common boxed Errno values, to prevent
|
||||||
|
// allocations at runtime.
|
||||||
|
func errnoErr(e syscall.Errno) error {
|
||||||
|
switch e {
|
||||||
|
case 0:
|
||||||
|
return nil
|
||||||
|
case errnoERROR_IO_PENDING:
|
||||||
|
return errERROR_IO_PENDING
|
||||||
|
}
|
||||||
|
// TODO: add more here, after collecting data on the common
|
||||||
|
// error values see on Windows. (perhaps when running
|
||||||
|
// all.bat?)
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
var (
|
||||||
|
modnci = windows.NewLazySystemDLL("nci.dll")
|
||||||
|
|
||||||
|
procNciSetConnectionName = modnci.NewProc("NciSetConnectionName")
|
||||||
|
procNciGetConnectionName = modnci.NewProc("NciGetConnectionName")
|
||||||
|
)
|
||||||
|
|
||||||
|
func nciSetConnectionName(guid *windows.GUID, newName *uint16) (ret error) {
|
||||||
|
r0, _, _ := syscall.Syscall(procNciSetConnectionName.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(newName)), 0)
|
||||||
|
if r0 != 0 {
|
||||||
|
ret = syscall.Errno(r0)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func nciGetConnectionName(guid *windows.GUID, destName *uint16, inDestNameBytes uint32, outDestNameBytes *uint32) (ret error) {
|
||||||
|
r0, _, _ := syscall.Syscall6(procNciGetConnectionName.Addr(), 4, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(destName)), uintptr(inDestNameBytes), uintptr(unsafe.Pointer(outDestNameBytes)), 0, 0)
|
||||||
|
if r0 != 0 {
|
||||||
|
ret = syscall.Errno(r0)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package netshell
|
|
||||||
|
|
||||||
import (
|
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
modnetshell = windows.NewLazySystemDLL("netshell.dll")
|
|
||||||
procHrRenameConnection = modnetshell.NewProc("HrRenameConnection")
|
|
||||||
)
|
|
||||||
|
|
||||||
func HrRenameConnection(guid *windows.GUID, newName *uint16) (err error) {
|
|
||||||
err = procHrRenameConnection.Find()
|
|
||||||
if err != nil {
|
|
||||||
// Missing from servercore, so we can't presume it's always there.
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
ret, _, _ := syscall.Syscall(procHrRenameConnection.Addr(), 2, uintptr(unsafe.Pointer(guid)), uintptr(unsafe.Pointer(newName)), 0)
|
|
||||||
if ret != 0 {
|
|
||||||
err = syscall.Errno(ret)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
@@ -5,4 +5,4 @@
|
|||||||
|
|
||||||
package registry
|
package registry
|
||||||
|
|
||||||
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zregistry_windows.go registry_windows.go
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zregistry_windows.go registry_windows.go
|
||||||
|
|||||||
117
tun/wintun/ring_windows.go
Normal file
117
tun/wintun/ring_windows.go
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
/* SPDX-License-Identifier: MIT
|
||||||
|
*
|
||||||
|
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||||
|
*/
|
||||||
|
|
||||||
|
package wintun
|
||||||
|
|
||||||
|
import (
|
||||||
|
"runtime"
|
||||||
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
PacketAlignment = 4 // Number of bytes packets are aligned to in rings
|
||||||
|
PacketSizeMax = 0xffff // Maximum packet size
|
||||||
|
PacketCapacity = 0x800000 // Ring capacity, 8MiB
|
||||||
|
PacketTrailingSize = uint32(unsafe.Sizeof(PacketHeader{})) + ((PacketSizeMax + (PacketAlignment - 1)) &^ (PacketAlignment - 1)) - PacketAlignment
|
||||||
|
ioctlRegisterRings = (51820 << 16) | (0x970 << 2) | 0 /*METHOD_BUFFERED*/ | (0x3 /*FILE_READ_DATA | FILE_WRITE_DATA*/ << 14)
|
||||||
|
)
|
||||||
|
|
||||||
|
type PacketHeader struct {
|
||||||
|
Size uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
type Packet struct {
|
||||||
|
PacketHeader
|
||||||
|
Data [PacketSizeMax]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type Ring struct {
|
||||||
|
Head uint32
|
||||||
|
Tail uint32
|
||||||
|
Alertable int32
|
||||||
|
Data [PacketCapacity + PacketTrailingSize]byte
|
||||||
|
}
|
||||||
|
|
||||||
|
type RingDescriptor struct {
|
||||||
|
Send, Receive struct {
|
||||||
|
Size uint32
|
||||||
|
Ring *Ring
|
||||||
|
TailMoved windows.Handle
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap returns value modulo ring capacity
|
||||||
|
func (rb *Ring) Wrap(value uint32) uint32 {
|
||||||
|
return value & (PacketCapacity - 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Aligns a packet size to PacketAlignment
|
||||||
|
func PacketAlign(size uint32) uint32 {
|
||||||
|
return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRingDescriptor() (descriptor *RingDescriptor, err error) {
|
||||||
|
descriptor = new(RingDescriptor)
|
||||||
|
allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
descriptor.free()
|
||||||
|
descriptor = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
|
||||||
|
descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion))
|
||||||
|
descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
|
||||||
|
descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{})))
|
||||||
|
descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
|
||||||
|
if err != nil {
|
||||||
|
windows.CloseHandle(descriptor.Send.TailMoved)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() })
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (descriptor *RingDescriptor) free() {
|
||||||
|
if descriptor.Send.Ring != nil {
|
||||||
|
windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE)
|
||||||
|
descriptor.Send.Ring = nil
|
||||||
|
descriptor.Receive.Ring = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (descriptor *RingDescriptor) Close() {
|
||||||
|
if descriptor.Send.TailMoved != 0 {
|
||||||
|
windows.CloseHandle(descriptor.Send.TailMoved)
|
||||||
|
descriptor.Send.TailMoved = 0
|
||||||
|
}
|
||||||
|
if descriptor.Send.TailMoved != 0 {
|
||||||
|
windows.CloseHandle(descriptor.Receive.TailMoved)
|
||||||
|
descriptor.Receive.TailMoved = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (wintun *Interface) Register(descriptor *RingDescriptor) (windows.Handle, error) {
|
||||||
|
handle, err := wintun.handle()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
var bytesReturned uint32
|
||||||
|
err = windows.DeviceIoControl(handle, ioctlRegisterRings, (*byte)(unsafe.Pointer(descriptor)), uint32(unsafe.Sizeof(*descriptor)), nil, 0, &bytesReturned, nil)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return handle, nil
|
||||||
|
}
|
||||||
@@ -5,4 +5,4 @@
|
|||||||
|
|
||||||
package setupapi
|
package setupapi
|
||||||
|
|
||||||
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsetupapi_windows.go setupapi_windows.go
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsetupapi_windows.go setupapi_windows.go
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
@@ -22,7 +21,7 @@ import (
|
|||||||
func SetupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName string) (deviceInfoSet DevInfo, err error) {
|
func SetupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName string) (deviceInfoSet DevInfo, err error) {
|
||||||
var machineNameUTF16 *uint16
|
var machineNameUTF16 *uint16
|
||||||
if machineName != "" {
|
if machineName != "" {
|
||||||
machineNameUTF16, err = syscall.UTF16PtrFromString(machineName)
|
machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -40,8 +39,8 @@ func SetupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo) (deviceInfoSetDetailD
|
|||||||
return data, setupDiGetDeviceInfoListDetail(deviceInfoSet, data)
|
return data, setupDiGetDeviceInfoListDetail(deviceInfoSet, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDeviceInfoListDetail method retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name.
|
// DeviceInfoListDetail method retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name.
|
||||||
func (deviceInfoSet DevInfo) GetDeviceInfoListDetail() (*DevInfoListDetailData, error) {
|
func (deviceInfoSet DevInfo) DeviceInfoListDetail() (*DevInfoListDetailData, error) {
|
||||||
return SetupDiGetDeviceInfoListDetail(deviceInfoSet)
|
return SetupDiGetDeviceInfoListDetail(deviceInfoSet)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,14 +48,14 @@ func (deviceInfoSet DevInfo) GetDeviceInfoListDetail() (*DevInfoListDetailData,
|
|||||||
|
|
||||||
// SetupDiCreateDeviceInfo function creates a new device information element and adds it as a new member to the specified device information set.
|
// SetupDiCreateDeviceInfo function creates a new device information element and adds it as a new member to the specified device information set.
|
||||||
func SetupDiCreateDeviceInfo(deviceInfoSet DevInfo, deviceName string, classGUID *windows.GUID, deviceDescription string, hwndParent uintptr, creationFlags DICD) (deviceInfoData *DevInfoData, err error) {
|
func SetupDiCreateDeviceInfo(deviceInfoSet DevInfo, deviceName string, classGUID *windows.GUID, deviceDescription string, hwndParent uintptr, creationFlags DICD) (deviceInfoData *DevInfoData, err error) {
|
||||||
deviceNameUTF16, err := syscall.UTF16PtrFromString(deviceName)
|
deviceNameUTF16, err := windows.UTF16PtrFromString(deviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
var deviceDescriptionUTF16 *uint16
|
var deviceDescriptionUTF16 *uint16
|
||||||
if deviceDescription != "" {
|
if deviceDescription != "" {
|
||||||
deviceDescriptionUTF16, err = syscall.UTF16PtrFromString(deviceDescription)
|
deviceDescriptionUTF16, err = windows.UTF16PtrFromString(deviceDescription)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -135,8 +134,8 @@ func SetupDiGetSelectedDriver(deviceInfoSet DevInfo, deviceInfoData *DevInfoData
|
|||||||
return data, setupDiGetSelectedDriver(deviceInfoSet, deviceInfoData, data)
|
return data, setupDiGetSelectedDriver(deviceInfoSet, deviceInfoData, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSelectedDriver method retrieves the selected driver for a device information set or a particular device information element.
|
// SelectedDriver method retrieves the selected driver for a device information set or a particular device information element.
|
||||||
func (deviceInfoSet DevInfo) GetSelectedDriver(deviceInfoData *DevInfoData) (*DrvInfoData, error) {
|
func (deviceInfoSet DevInfo) SelectedDriver(deviceInfoData *DevInfoData) (*DrvInfoData, error) {
|
||||||
return SetupDiGetSelectedDriver(deviceInfoSet, deviceInfoData)
|
return SetupDiGetSelectedDriver(deviceInfoSet, deviceInfoData)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -151,38 +150,25 @@ func (deviceInfoSet DevInfo) SetSelectedDriver(deviceInfoData *DevInfoData, driv
|
|||||||
|
|
||||||
// SetupDiGetDriverInfoDetail function retrieves driver information detail for a device information set or a particular device information element in the device information set.
|
// SetupDiGetDriverInfoDetail function retrieves driver information detail for a device information set or a particular device information element in the device information set.
|
||||||
func SetupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) {
|
func SetupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) {
|
||||||
const bufCapacity = 0x800
|
reqSize := uint32(2048)
|
||||||
buf := [bufCapacity]byte{}
|
for {
|
||||||
var bufLen uint32
|
buf := make([]byte, reqSize)
|
||||||
|
|
||||||
data := (*DrvInfoDetailData)(unsafe.Pointer(&buf[0]))
|
|
||||||
data.size = sizeofDrvInfoDetailData
|
|
||||||
|
|
||||||
err := setupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData, data, bufCapacity, &bufLen)
|
|
||||||
if err == nil {
|
|
||||||
// The buffer was was sufficiently big.
|
|
||||||
data.size = bufLen
|
|
||||||
return data, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_INSUFFICIENT_BUFFER {
|
|
||||||
// The buffer was too small. Now that we got the required size, create another one big enough and retry.
|
|
||||||
buf := make([]byte, bufLen)
|
|
||||||
data := (*DrvInfoDetailData)(unsafe.Pointer(&buf[0]))
|
data := (*DrvInfoDetailData)(unsafe.Pointer(&buf[0]))
|
||||||
data.size = sizeofDrvInfoDetailData
|
data.size = sizeofDrvInfoDetailData
|
||||||
|
err := setupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData, data, uint32(len(buf)), &reqSize)
|
||||||
err = setupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData, data, bufLen, &bufLen)
|
if err == windows.ERROR_INSUFFICIENT_BUFFER {
|
||||||
if err == nil {
|
continue
|
||||||
data.size = bufLen
|
|
||||||
return data, nil
|
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
data.size = reqSize
|
||||||
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDriverInfoDetail method retrieves driver information detail for a device information set or a particular device information element in the device information set.
|
// DriverInfoDetail method retrieves driver information detail for a device information set or a particular device information element in the device information set.
|
||||||
func (deviceInfoSet DevInfo) GetDriverInfoDetail(deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) {
|
func (deviceInfoSet DevInfo) DriverInfoDetail(deviceInfoData *DevInfoData, driverInfoData *DrvInfoData) (*DrvInfoDetailData, error) {
|
||||||
return SetupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData)
|
return SetupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -199,14 +185,14 @@ func (deviceInfoSet DevInfo) DestroyDriverInfoList(deviceInfoData *DevInfoData,
|
|||||||
func SetupDiGetClassDevsEx(classGUID *windows.GUID, enumerator string, hwndParent uintptr, flags DIGCF, deviceInfoSet DevInfo, machineName string) (handle DevInfo, err error) {
|
func SetupDiGetClassDevsEx(classGUID *windows.GUID, enumerator string, hwndParent uintptr, flags DIGCF, deviceInfoSet DevInfo, machineName string) (handle DevInfo, err error) {
|
||||||
var enumeratorUTF16 *uint16
|
var enumeratorUTF16 *uint16
|
||||||
if enumerator != "" {
|
if enumerator != "" {
|
||||||
enumeratorUTF16, err = syscall.UTF16PtrFromString(enumerator)
|
enumeratorUTF16, err = windows.UTF16PtrFromString(enumerator)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
var machineNameUTF16 *uint16
|
var machineNameUTF16 *uint16
|
||||||
if machineName != "" {
|
if machineName != "" {
|
||||||
machineNameUTF16, err = syscall.UTF16PtrFromString(machineName)
|
machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -239,24 +225,19 @@ func (deviceInfoSet DevInfo) OpenDevRegKey(DeviceInfoData *DevInfoData, Scope DI
|
|||||||
|
|
||||||
// SetupDiGetDeviceRegistryProperty function retrieves a specified Plug and Play device property.
|
// SetupDiGetDeviceRegistryProperty function retrieves a specified Plug and Play device property.
|
||||||
func SetupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP) (value interface{}, err error) {
|
func SetupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, property SPDRP) (value interface{}, err error) {
|
||||||
buf := make([]byte, 0x100)
|
reqSize := uint32(256)
|
||||||
var dataType, bufLen uint32
|
for {
|
||||||
err = setupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, &dataType, &buf[0], uint32(cap(buf)), &bufLen)
|
var dataType uint32
|
||||||
if err == nil {
|
buf := make([]byte, reqSize)
|
||||||
// The buffer was sufficiently big.
|
err = setupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, &dataType, &buf[0], uint32(len(buf)), &reqSize)
|
||||||
return getRegistryValue(buf[:bufLen], dataType)
|
if err == windows.ERROR_INSUFFICIENT_BUFFER {
|
||||||
}
|
continue
|
||||||
|
|
||||||
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_INSUFFICIENT_BUFFER {
|
|
||||||
// The buffer was too small. Now that we got the required size, create another one big enough and retry.
|
|
||||||
buf = make([]byte, bufLen)
|
|
||||||
err = setupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, &dataType, &buf[0], uint32(cap(buf)), &bufLen)
|
|
||||||
if err == nil {
|
|
||||||
return getRegistryValue(buf[:bufLen], dataType)
|
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
return getRegistryValue(buf[:reqSize], dataType)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) {
|
func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) {
|
||||||
@@ -323,8 +304,8 @@ func wcslen(str []uint16) int {
|
|||||||
return len(str)
|
return len(str)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDeviceRegistryProperty method retrieves a specified Plug and Play device property.
|
// DeviceRegistryProperty method retrieves a specified Plug and Play device property.
|
||||||
func (deviceInfoSet DevInfo) GetDeviceRegistryProperty(deviceInfoData *DevInfoData, property SPDRP) (interface{}, error) {
|
func (deviceInfoSet DevInfo) DeviceRegistryProperty(deviceInfoData *DevInfoData, property SPDRP) (interface{}, error) {
|
||||||
return SetupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property)
|
return SetupDiGetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -340,6 +321,7 @@ func (deviceInfoSet DevInfo) SetDeviceRegistryProperty(deviceInfoData *DevInfoDa
|
|||||||
return SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, propertyBuffers)
|
return SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, propertyBuffers)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetDeviceRegistryPropertyString method sets a Plug and Play device property string for a device.
|
||||||
func (deviceInfoSet DevInfo) SetDeviceRegistryPropertyString(deviceInfoData *DevInfoData, property SPDRP, str string) error {
|
func (deviceInfoSet DevInfo) SetDeviceRegistryPropertyString(deviceInfoData *DevInfoData, property SPDRP, str string) error {
|
||||||
str16, err := windows.UTF16FromString(str)
|
str16, err := windows.UTF16FromString(str)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -360,16 +342,39 @@ func SetupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInf
|
|||||||
return params, setupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData, params)
|
return params, setupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData, params)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetDeviceInstallParams method retrieves device installation parameters for a device information set or a particular device information element.
|
// DeviceInstallParams method retrieves device installation parameters for a device information set or a particular device information element.
|
||||||
func (deviceInfoSet DevInfo) GetDeviceInstallParams(deviceInfoData *DevInfoData) (*DevInstallParams, error) {
|
func (deviceInfoSet DevInfo) DeviceInstallParams(deviceInfoData *DevInfoData) (*DevInstallParams, error) {
|
||||||
return SetupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData)
|
return SetupDiGetDeviceInstallParams(deviceInfoSet, deviceInfoData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//sys setupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, instanceId *uint16, instanceIdSize uint32, instanceIdRequiredSize *uint32) (err error) = setupapi.SetupDiGetDeviceInstanceIdW
|
||||||
|
|
||||||
|
// SetupDiGetDeviceInstanceId function retrieves the instance ID of the device.
|
||||||
|
func SetupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (string, error) {
|
||||||
|
reqSize := uint32(1024)
|
||||||
|
for {
|
||||||
|
buf := make([]uint16, reqSize)
|
||||||
|
err := setupDiGetDeviceInstanceId(deviceInfoSet, deviceInfoData, &buf[0], uint32(len(buf)), &reqSize)
|
||||||
|
if err == windows.ERROR_INSUFFICIENT_BUFFER {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return windows.UTF16ToString(buf), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeviceInstanceID method retrieves the instance ID of the device.
|
||||||
|
func (deviceInfoSet DevInfo) DeviceInstanceID(deviceInfoData *DevInfoData) (string, error) {
|
||||||
|
return SetupDiGetDeviceInstanceId(deviceInfoSet, deviceInfoData)
|
||||||
|
}
|
||||||
|
|
||||||
// SetupDiGetClassInstallParams function retrieves class installation parameters for a device information set or a particular device information element.
|
// SetupDiGetClassInstallParams function retrieves class installation parameters for a device information set or a particular device information element.
|
||||||
//sys SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetClassInstallParamsW
|
//sys SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) = setupapi.SetupDiGetClassInstallParamsW
|
||||||
|
|
||||||
// GetClassInstallParams method retrieves class installation parameters for a device information set or a particular device information element.
|
// ClassInstallParams method retrieves class installation parameters for a device information set or a particular device information element.
|
||||||
func (deviceInfoSet DevInfo) GetClassInstallParams(deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) error {
|
func (deviceInfoSet DevInfo) ClassInstallParams(deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) error {
|
||||||
return SetupDiGetClassInstallParams(deviceInfoSet, deviceInfoData, classInstallParams, classInstallParamsSize, requiredSize)
|
return SetupDiGetClassInstallParams(deviceInfoSet, deviceInfoData, classInstallParams, classInstallParamsSize, requiredSize)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -396,7 +401,7 @@ func SetupDiClassNameFromGuidEx(classGUID *windows.GUID, machineName string) (cl
|
|||||||
|
|
||||||
var machineNameUTF16 *uint16
|
var machineNameUTF16 *uint16
|
||||||
if machineName != "" {
|
if machineName != "" {
|
||||||
machineNameUTF16, err = syscall.UTF16PtrFromString(machineName)
|
machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -415,39 +420,31 @@ func SetupDiClassNameFromGuidEx(classGUID *windows.GUID, machineName string) (cl
|
|||||||
|
|
||||||
// SetupDiClassGuidsFromNameEx function retrieves the GUIDs associated with the specified class name. This resulting list contains the classes currently installed on a local or remote computer.
|
// SetupDiClassGuidsFromNameEx function retrieves the GUIDs associated with the specified class name. This resulting list contains the classes currently installed on a local or remote computer.
|
||||||
func SetupDiClassGuidsFromNameEx(className string, machineName string) ([]windows.GUID, error) {
|
func SetupDiClassGuidsFromNameEx(className string, machineName string) ([]windows.GUID, error) {
|
||||||
classNameUTF16, err := syscall.UTF16PtrFromString(className)
|
classNameUTF16, err := windows.UTF16PtrFromString(className)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
const bufCapacity = 4
|
|
||||||
var buf [bufCapacity]windows.GUID
|
|
||||||
var bufLen uint32
|
|
||||||
|
|
||||||
var machineNameUTF16 *uint16
|
var machineNameUTF16 *uint16
|
||||||
if machineName != "" {
|
if machineName != "" {
|
||||||
machineNameUTF16, err = syscall.UTF16PtrFromString(machineName)
|
machineNameUTF16, err = windows.UTF16PtrFromString(machineName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
err = setupDiClassGuidsFromNameEx(classNameUTF16, &buf[0], bufCapacity, &bufLen, machineNameUTF16, 0)
|
reqSize := uint32(4)
|
||||||
if err == nil {
|
for {
|
||||||
// The GUID array was sufficiently big. Return its slice.
|
buf := make([]windows.GUID, reqSize)
|
||||||
return buf[:bufLen], nil
|
err = setupDiClassGuidsFromNameEx(classNameUTF16, &buf[0], uint32(len(buf)), &reqSize, machineNameUTF16, 0)
|
||||||
}
|
if err == windows.ERROR_INSUFFICIENT_BUFFER {
|
||||||
|
continue
|
||||||
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_INSUFFICIENT_BUFFER {
|
|
||||||
// The GUID array was too small. Now that we got the required size, create another one big enough and retry.
|
|
||||||
buf := make([]windows.GUID, bufLen)
|
|
||||||
err = setupDiClassGuidsFromNameEx(classNameUTF16, &buf[0], bufLen, &bufLen, machineNameUTF16, 0)
|
|
||||||
if err == nil {
|
|
||||||
return buf[:bufLen], nil
|
|
||||||
}
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return buf[:reqSize], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//sys setupDiGetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiGetSelectedDevice
|
//sys setupDiGetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData) (err error) = setupapi.SetupDiGetSelectedDevice
|
||||||
@@ -460,8 +457,8 @@ func SetupDiGetSelectedDevice(deviceInfoSet DevInfo) (*DevInfoData, error) {
|
|||||||
return data, setupDiGetSelectedDevice(deviceInfoSet, data)
|
return data, setupDiGetSelectedDevice(deviceInfoSet, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetSelectedDevice method retrieves the selected device information element in a device information set.
|
// SelectedDevice method retrieves the selected device information element in a device information set.
|
||||||
func (deviceInfoSet DevInfo) GetSelectedDevice() (*DevInfoData, error) {
|
func (deviceInfoSet DevInfo) SelectedDevice() (*DevInfoData, error) {
|
||||||
return SetupDiGetSelectedDevice(deviceInfoSet)
|
return SetupDiGetSelectedDevice(deviceInfoSet)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -472,3 +469,38 @@ func (deviceInfoSet DevInfo) GetSelectedDevice() (*DevInfoData, error) {
|
|||||||
func (deviceInfoSet DevInfo) SetSelectedDevice(deviceInfoData *DevInfoData) error {
|
func (deviceInfoSet DevInfo) SetSelectedDevice(deviceInfoData *DevInfoData) error {
|
||||||
return SetupDiSetSelectedDevice(deviceInfoSet, deviceInfoData)
|
return SetupDiSetSelectedDevice(deviceInfoSet, deviceInfoData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//sys cm_Get_Device_Interface_List_Size(len *uint32, interfaceClass *windows.GUID, deviceID *uint16, flags uint32) (ret uint32) = CfgMgr32.CM_Get_Device_Interface_List_SizeW
|
||||||
|
//sys cm_Get_Device_Interface_List(interfaceClass *windows.GUID, deviceID *uint16, buffer *uint16, bufferLen uint32, flags uint32) (ret uint32) = CfgMgr32.CM_Get_Device_Interface_ListW
|
||||||
|
|
||||||
|
func CM_Get_Device_Interface_List(deviceID string, interfaceClass *windows.GUID, flags uint32) ([]string, error) {
|
||||||
|
deviceID16, err := windows.UTF16PtrFromString(deviceID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var buf []uint16
|
||||||
|
var buflen uint32
|
||||||
|
for {
|
||||||
|
if ret := cm_Get_Device_Interface_List_Size(&buflen, interfaceClass, deviceID16, flags); ret != CR_SUCCESS {
|
||||||
|
return nil, fmt.Errorf("CfgMgr error: 0x%x", ret)
|
||||||
|
}
|
||||||
|
buf = make([]uint16, buflen)
|
||||||
|
if ret := cm_Get_Device_Interface_List(interfaceClass, deviceID16, &buf[0], buflen, flags); ret == CR_SUCCESS {
|
||||||
|
break
|
||||||
|
} else if ret != CR_BUFFER_SMALL {
|
||||||
|
return nil, fmt.Errorf("CfgMgr error: 0x%x", ret)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var interfaces []string
|
||||||
|
for i := 0; i < len(buf); {
|
||||||
|
j := i + wcslen(buf[i:])
|
||||||
|
if i < j {
|
||||||
|
interfaces = append(interfaces, windows.UTF16ToString(buf[i:j]))
|
||||||
|
}
|
||||||
|
i = j + 1
|
||||||
|
}
|
||||||
|
if interfaces == nil {
|
||||||
|
return nil, fmt.Errorf("no interfaces found")
|
||||||
|
}
|
||||||
|
return interfaces, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -8,11 +8,9 @@ package setupapi
|
|||||||
import (
|
import (
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
"golang.zx2c4.com/wireguard/tun/wintun/guid"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}}
|
var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0x11ce, Data4: [8]byte{0xbf, 0xc1, 0x08, 0x00, 0x2b, 0xe1, 0x03, 0x18}}
|
||||||
@@ -24,24 +22,24 @@ func init() {
|
|||||||
|
|
||||||
func TestSetupDiCreateDeviceInfoListEx(t *testing.T) {
|
func TestSetupDiCreateDeviceInfoListEx(t *testing.T) {
|
||||||
devInfoList, err := SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
|
devInfoList, err := SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
|
||||||
if err == nil {
|
if err != nil {
|
||||||
devInfoList.Close()
|
|
||||||
} else {
|
|
||||||
t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error())
|
t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error())
|
||||||
|
} else {
|
||||||
|
devInfoList.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
devInfoList, err = SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, computerName)
|
devInfoList, err = SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, computerName)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
devInfoList.Close()
|
|
||||||
} else {
|
|
||||||
t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error())
|
t.Errorf("Error calling SetupDiCreateDeviceInfoListEx: %s", err.Error())
|
||||||
|
} else {
|
||||||
|
devInfoList.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
devInfoList, err = SetupDiCreateDeviceInfoListEx(nil, 0, "")
|
devInfoList, err = SetupDiCreateDeviceInfoListEx(nil, 0, "")
|
||||||
if err == nil {
|
if err != nil {
|
||||||
devInfoList.Close()
|
|
||||||
} else {
|
|
||||||
t.Errorf("Error calling SetupDiCreateDeviceInfoListEx(nil): %s", err.Error())
|
t.Errorf("Error calling SetupDiCreateDeviceInfoListEx(nil): %s", err.Error())
|
||||||
|
} else {
|
||||||
|
devInfoList.Close()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,7 +50,7 @@ func TestSetupDiGetDeviceInfoListDetail(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer devInfoList.Close()
|
defer devInfoList.Close()
|
||||||
|
|
||||||
data, err := devInfoList.GetDeviceInfoListDetail()
|
data, err := devInfoList.DeviceInfoListDetail()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error())
|
t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error())
|
||||||
} else {
|
} else {
|
||||||
@@ -64,7 +62,7 @@ func TestSetupDiGetDeviceInfoListDetail(t *testing.T) {
|
|||||||
t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine handle")
|
t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine handle")
|
||||||
}
|
}
|
||||||
|
|
||||||
if data.GetRemoteMachineName() != "" {
|
if data.RemoteMachineName() != "" {
|
||||||
t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine name")
|
t.Error("SetupDiGetDeviceInfoListDetail returned non-NULL remote machine name")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -75,7 +73,7 @@ func TestSetupDiGetDeviceInfoListDetail(t *testing.T) {
|
|||||||
}
|
}
|
||||||
defer devInfoList.Close()
|
defer devInfoList.Close()
|
||||||
|
|
||||||
data, err = devInfoList.GetDeviceInfoListDetail()
|
data, err = devInfoList.DeviceInfoListDetail()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error())
|
t.Errorf("Error calling SetupDiGetDeviceInfoListDetail: %s", err.Error())
|
||||||
} else {
|
} else {
|
||||||
@@ -87,14 +85,14 @@ func TestSetupDiGetDeviceInfoListDetail(t *testing.T) {
|
|||||||
t.Error("SetupDiGetDeviceInfoListDetail returned NULL remote machine handle")
|
t.Error("SetupDiGetDeviceInfoListDetail returned NULL remote machine handle")
|
||||||
}
|
}
|
||||||
|
|
||||||
if data.GetRemoteMachineName() != computerName {
|
if data.RemoteMachineName() != computerName {
|
||||||
t.Error("SetupDiGetDeviceInfoListDetail returned different remote machine name")
|
t.Error("SetupDiGetDeviceInfoListDetail returned different remote machine name")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
data = &DevInfoListDetailData{}
|
data = &DevInfoListDetailData{}
|
||||||
data.SetRemoteMachineName("foobar")
|
data.SetRemoteMachineName("foobar")
|
||||||
if data.GetRemoteMachineName() != "foobar" {
|
if data.RemoteMachineName() != "foobar" {
|
||||||
t.Error("DevInfoListDetailData.(Get|Set)RemoteMachineName() differ")
|
t.Error("DevInfoListDetailData.(Get|Set)RemoteMachineName() differ")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -114,7 +112,7 @@ func TestSetupDiCreateDeviceInfo(t *testing.T) {
|
|||||||
devInfoData, err := devInfoList.CreateDeviceInfo(deviceClassNetName, &deviceClassNetGUID, "This is a test device", 0, DICD_GENERATE_ID)
|
devInfoData, err := devInfoList.CreateDeviceInfo(deviceClassNetName, &deviceClassNetGUID, "This is a test device", 0, DICD_GENERATE_ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Access denied is expected, as the SetupDiCreateDeviceInfo() require elevation to succeed.
|
// Access denied is expected, as the SetupDiCreateDeviceInfo() require elevation to succeed.
|
||||||
if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_ACCESS_DENIED {
|
if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_ACCESS_DENIED {
|
||||||
t.Errorf("Error calling SetupDiCreateDeviceInfo: %s", err.Error())
|
t.Errorf("Error calling SetupDiCreateDeviceInfo: %s", err.Error())
|
||||||
}
|
}
|
||||||
} else if devInfoData.ClassGUID != deviceClassNetGUID {
|
} else if devInfoData.ClassGUID != deviceClassNetGUID {
|
||||||
@@ -132,7 +130,7 @@ func TestSetupDiEnumDeviceInfo(t *testing.T) {
|
|||||||
for i := 0; true; i++ {
|
for i := 0; true; i++ {
|
||||||
data, err := devInfoList.EnumDeviceInfo(i)
|
data, err := devInfoList.EnumDeviceInfo(i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
|
if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@@ -141,6 +139,11 @@ func TestSetupDiEnumDeviceInfo(t *testing.T) {
|
|||||||
if data.ClassGUID != deviceClassNetGUID {
|
if data.ClassGUID != deviceClassNetGUID {
|
||||||
t.Error("SetupDiEnumDeviceInfo returned different class GUID")
|
t.Error("SetupDiEnumDeviceInfo returned different class GUID")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
_, err = devInfoList.DeviceInstanceID(data)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("Error calling SetupDiGetDeviceInstanceId: %s", err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,7 +157,7 @@ func TestDevInfo_BuildDriverInfoList(t *testing.T) {
|
|||||||
for i := 0; true; i++ {
|
for i := 0; true; i++ {
|
||||||
deviceData, err := devInfoList.EnumDeviceInfo(i)
|
deviceData, err := devInfoList.EnumDeviceInfo(i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
|
if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@@ -171,7 +174,7 @@ func TestDevInfo_BuildDriverInfoList(t *testing.T) {
|
|||||||
for j := 0; true; j++ {
|
for j := 0; true; j++ {
|
||||||
driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, j)
|
driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, j)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
|
if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@@ -210,7 +213,7 @@ func TestDevInfo_BuildDriverInfoList(t *testing.T) {
|
|||||||
selectedDriverData = driverData
|
selectedDriverData = driverData
|
||||||
}
|
}
|
||||||
|
|
||||||
driverDetailData, err := devInfoList.GetDriverInfoDetail(deviceData, driverData)
|
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error calling SetupDiGetDriverInfoDetail: %s", err.Error())
|
t.Errorf("Error calling SetupDiGetDriverInfoDetail: %s", err.Error())
|
||||||
}
|
}
|
||||||
@@ -218,10 +221,10 @@ func TestDevInfo_BuildDriverInfoList(t *testing.T) {
|
|||||||
if driverDetailData.IsCompatible("foobar-aab6e3a4-144e-4786-88d3-6cec361e1edd") {
|
if driverDetailData.IsCompatible("foobar-aab6e3a4-144e-4786-88d3-6cec361e1edd") {
|
||||||
t.Error("Invalid HWID compatibitlity reported")
|
t.Error("Invalid HWID compatibitlity reported")
|
||||||
}
|
}
|
||||||
if !driverDetailData.IsCompatible(strings.ToUpper(driverDetailData.GetHardwareID())) {
|
if !driverDetailData.IsCompatible(strings.ToUpper(driverDetailData.HardwareID())) {
|
||||||
t.Error("HWID compatibitlity missed")
|
t.Error("HWID compatibitlity missed")
|
||||||
}
|
}
|
||||||
a := driverDetailData.GetCompatIDs()
|
a := driverDetailData.CompatIDs()
|
||||||
for k := range a {
|
for k := range a {
|
||||||
if !driverDetailData.IsCompatible(strings.ToUpper(a[k])) {
|
if !driverDetailData.IsCompatible(strings.ToUpper(a[k])) {
|
||||||
t.Error("HWID compatibitlity missed")
|
t.Error("HWID compatibitlity missed")
|
||||||
@@ -229,7 +232,7 @@ func TestDevInfo_BuildDriverInfoList(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
selectedDriverData2, err := devInfoList.GetSelectedDriver(deviceData)
|
selectedDriverData2, err := devInfoList.SelectedDriver(deviceData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error calling SetupDiGetSelectedDriver: %s", err.Error())
|
t.Errorf("Error calling SetupDiGetSelectedDriver: %s", err.Error())
|
||||||
} else if *selectedDriverData != *selectedDriverData2 {
|
} else if *selectedDriverData != *selectedDriverData2 {
|
||||||
@@ -239,35 +242,35 @@ func TestDevInfo_BuildDriverInfoList(t *testing.T) {
|
|||||||
|
|
||||||
data := &DrvInfoData{}
|
data := &DrvInfoData{}
|
||||||
data.SetDescription("foobar")
|
data.SetDescription("foobar")
|
||||||
if data.GetDescription() != "foobar" {
|
if data.Description() != "foobar" {
|
||||||
t.Error("DrvInfoData.(Get|Set)Description() differ")
|
t.Error("DrvInfoData.(Get|Set)Description() differ")
|
||||||
}
|
}
|
||||||
data.SetMfgName("foobar")
|
data.SetMfgName("foobar")
|
||||||
if data.GetMfgName() != "foobar" {
|
if data.MfgName() != "foobar" {
|
||||||
t.Error("DrvInfoData.(Get|Set)MfgName() differ")
|
t.Error("DrvInfoData.(Get|Set)MfgName() differ")
|
||||||
}
|
}
|
||||||
data.SetProviderName("foobar")
|
data.SetProviderName("foobar")
|
||||||
if data.GetProviderName() != "foobar" {
|
if data.ProviderName() != "foobar" {
|
||||||
t.Error("DrvInfoData.(Get|Set)ProviderName() differ")
|
t.Error("DrvInfoData.(Get|Set)ProviderName() differ")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestSetupDiGetClassDevsEx(t *testing.T) {
|
func TestSetupDiGetClassDevsEx(t *testing.T) {
|
||||||
devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "PCI", 0, DIGCF_PRESENT, DevInfo(0), computerName)
|
devInfoList, err := SetupDiGetClassDevsEx(&deviceClassNetGUID, "PCI", 0, DIGCF_PRESENT, DevInfo(0), computerName)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
devInfoList.Close()
|
|
||||||
} else {
|
|
||||||
t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
|
t.Errorf("Error calling SetupDiGetClassDevsEx: %s", err.Error())
|
||||||
|
} else {
|
||||||
|
devInfoList.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
devInfoList, err = SetupDiGetClassDevsEx(nil, "", 0, DIGCF_PRESENT, DevInfo(0), "")
|
devInfoList, err = SetupDiGetClassDevsEx(nil, "", 0, DIGCF_PRESENT, DevInfo(0), "")
|
||||||
if err == nil {
|
if err != nil {
|
||||||
devInfoList.Close()
|
if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER {
|
||||||
t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail")
|
|
||||||
} else {
|
|
||||||
if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER {
|
|
||||||
t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail with ERROR_INVALID_PARAMETER")
|
t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail with ERROR_INVALID_PARAMETER")
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
devInfoList.Close()
|
||||||
|
t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -281,7 +284,7 @@ func TestSetupDiOpenDevRegKey(t *testing.T) {
|
|||||||
for i := 0; true; i++ {
|
for i := 0; true; i++ {
|
||||||
data, err := devInfoList.EnumDeviceInfo(i)
|
data, err := devInfoList.EnumDeviceInfo(i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
|
if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@@ -305,47 +308,47 @@ func TestSetupDiGetDeviceRegistryProperty(t *testing.T) {
|
|||||||
for i := 0; true; i++ {
|
for i := 0; true; i++ {
|
||||||
data, err := devInfoList.EnumDeviceInfo(i)
|
data, err := devInfoList.EnumDeviceInfo(i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
|
if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
val, err := devInfoList.GetDeviceRegistryProperty(data, SPDRP_CLASS)
|
val, err := devInfoList.DeviceRegistryProperty(data, SPDRP_CLASS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASS): %s", err.Error())
|
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASS): %s", err.Error())
|
||||||
} else if class, ok := val.(string); !ok || strings.ToLower(class) != "net" {
|
} else if class, ok := val.(string); !ok || strings.ToLower(class) != "net" {
|
||||||
t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASS) should return \"Net\"")
|
t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASS) should return \"Net\"")
|
||||||
}
|
}
|
||||||
|
|
||||||
val, err = devInfoList.GetDeviceRegistryProperty(data, SPDRP_CLASSGUID)
|
val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_CLASSGUID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error())
|
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error())
|
||||||
} else if valStr, ok := val.(string); !ok {
|
} else if valStr, ok := val.(string); !ok {
|
||||||
t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return string")
|
t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return string")
|
||||||
} else {
|
} else {
|
||||||
classGUID, err := guid.FromString(valStr)
|
classGUID, err := windows.GUIDFromString(valStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error parsing GUID returned by SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error())
|
t.Errorf("Error parsing GUID returned by SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID): %s", err.Error())
|
||||||
} else if *classGUID != deviceClassNetGUID {
|
} else if classGUID != deviceClassNetGUID {
|
||||||
t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return %x", deviceClassNetGUID)
|
t.Errorf("SetupDiGetDeviceRegistryProperty(SPDRP_CLASSGUID) should return %x", deviceClassNetGUID)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val, err = devInfoList.GetDeviceRegistryProperty(data, SPDRP_COMPATIBLEIDS)
|
val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_COMPATIBLEIDS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Some devices have no SPDRP_COMPATIBLEIDS.
|
// Some devices have no SPDRP_COMPATIBLEIDS.
|
||||||
if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_INVALID_DATA {
|
if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_DATA {
|
||||||
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_COMPATIBLEIDS): %s", err.Error())
|
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_COMPATIBLEIDS): %s", err.Error())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
val, err = devInfoList.GetDeviceRegistryProperty(data, SPDRP_CONFIGFLAGS)
|
val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_CONFIGFLAGS)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CONFIGFLAGS): %s", err.Error())
|
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_CONFIGFLAGS): %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
val, err = devInfoList.GetDeviceRegistryProperty(data, SPDRP_DEVICE_POWER_DATA)
|
val, err = devInfoList.DeviceRegistryProperty(data, SPDRP_DEVICE_POWER_DATA)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_DEVICE_POWER_DATA): %s", err.Error())
|
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_DEVICE_POWER_DATA): %s", err.Error())
|
||||||
}
|
}
|
||||||
@@ -362,13 +365,13 @@ func TestSetupDiGetDeviceInstallParams(t *testing.T) {
|
|||||||
for i := 0; true; i++ {
|
for i := 0; true; i++ {
|
||||||
data, err := devInfoList.EnumDeviceInfo(i)
|
data, err := devInfoList.EnumDeviceInfo(i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
|
if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = devInfoList.GetDeviceInstallParams(data)
|
_, err = devInfoList.DeviceInstallParams(data)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error calling SetupDiGetDeviceInstallParams: %s", err.Error())
|
t.Errorf("Error calling SetupDiGetDeviceInstallParams: %s", err.Error())
|
||||||
}
|
}
|
||||||
@@ -376,7 +379,7 @@ func TestSetupDiGetDeviceInstallParams(t *testing.T) {
|
|||||||
|
|
||||||
params := &DevInstallParams{}
|
params := &DevInstallParams{}
|
||||||
params.SetDriverPath("foobar")
|
params.SetDriverPath("foobar")
|
||||||
if params.GetDriverPath() != "foobar" {
|
if params.DriverPath() != "foobar" {
|
||||||
t.Error("DevInstallParams.(Get|Set)DriverPath() differ")
|
t.Error("DevInstallParams.(Get|Set)DriverPath() differ")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -397,12 +400,12 @@ func TestSetupDiClassNameFromGuidEx(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err = SetupDiClassNameFromGuidEx(nil, "")
|
_, err = SetupDiClassNameFromGuidEx(nil, "")
|
||||||
if err == nil {
|
if err != nil {
|
||||||
t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail")
|
if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_USER_BUFFER {
|
||||||
} else {
|
|
||||||
if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_INVALID_USER_BUFFER {
|
|
||||||
t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail with ERROR_INVALID_USER_BUFFER")
|
t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail with ERROR_INVALID_USER_BUFFER")
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -441,7 +444,7 @@ func TestSetupDiGetSelectedDevice(t *testing.T) {
|
|||||||
for i := 0; true; i++ {
|
for i := 0; true; i++ {
|
||||||
data, err := devInfoList.EnumDeviceInfo(i)
|
data, err := devInfoList.EnumDeviceInfo(i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
|
if errWin, ok := err.(windows.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
@@ -452,7 +455,7 @@ func TestSetupDiGetSelectedDevice(t *testing.T) {
|
|||||||
t.Errorf("Error calling SetupDiSetSelectedDevice: %s", err.Error())
|
t.Errorf("Error calling SetupDiSetSelectedDevice: %s", err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
data2, err := devInfoList.GetSelectedDevice()
|
data2, err := devInfoList.SelectedDevice()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("Error calling SetupDiGetSelectedDevice: %s", err.Error())
|
t.Errorf("Error calling SetupDiGetSelectedDevice: %s", err.Error())
|
||||||
} else if *data != *data2 {
|
} else if *data != *data2 {
|
||||||
@@ -461,12 +464,12 @@ func TestSetupDiGetSelectedDevice(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
err = devInfoList.SetSelectedDevice(nil)
|
err = devInfoList.SetSelectedDevice(nil)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
t.Errorf("SetupDiSetSelectedDevice(nil) should fail")
|
if errWin, ok := err.(windows.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER {
|
||||||
} else {
|
|
||||||
if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER {
|
|
||||||
t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER")
|
t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER")
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
t.Errorf("SetupDiSetSelectedDevice(nil) should fail")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ package setupapi
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
@@ -58,7 +57,7 @@ type DevInfoData struct {
|
|||||||
_ uintptr
|
_ uintptr
|
||||||
}
|
}
|
||||||
|
|
||||||
// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass).
|
// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supersedes the functionality of SetupDiGetDeviceInfoListClass).
|
||||||
type DevInfoListDetailData struct {
|
type DevInfoListDetailData struct {
|
||||||
size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const.
|
size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const.
|
||||||
ClassGUID windows.GUID
|
ClassGUID windows.GUID
|
||||||
@@ -66,12 +65,12 @@ type DevInfoListDetailData struct {
|
|||||||
remoteMachineName [SP_MAX_MACHINENAME_LENGTH]uint16
|
remoteMachineName [SP_MAX_MACHINENAME_LENGTH]uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func (data *DevInfoListDetailData) GetRemoteMachineName() string {
|
func (data *DevInfoListDetailData) RemoteMachineName() string {
|
||||||
return windows.UTF16ToString(data.remoteMachineName[:])
|
return windows.UTF16ToString(data.remoteMachineName[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (data *DevInfoListDetailData) SetRemoteMachineName(remoteMachineName string) error {
|
func (data *DevInfoListDetailData) SetRemoteMachineName(remoteMachineName string) error {
|
||||||
str, err := syscall.UTF16FromString(remoteMachineName)
|
str, err := windows.UTF16FromString(remoteMachineName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -138,12 +137,12 @@ type DevInstallParams struct {
|
|||||||
driverPath [windows.MAX_PATH]uint16
|
driverPath [windows.MAX_PATH]uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func (params *DevInstallParams) GetDriverPath() string {
|
func (params *DevInstallParams) DriverPath() string {
|
||||||
return windows.UTF16ToString(params.driverPath[:])
|
return windows.UTF16ToString(params.driverPath[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (params *DevInstallParams) SetDriverPath(driverPath string) error {
|
func (params *DevInstallParams) SetDriverPath(driverPath string) error {
|
||||||
str, err := syscall.UTF16FromString(driverPath)
|
str, err := windows.UTF16FromString(driverPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -268,15 +267,34 @@ func MakeClassInstallHeader(installFunction DI_FUNCTION) *ClassInstallHeader {
|
|||||||
return hdr
|
return hdr
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DICS_STATE specifies values indicating a change in a device's state
|
||||||
|
type DICS_STATE uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
DICS_ENABLE DICS_STATE = 0x00000001 // The device is being enabled.
|
||||||
|
DICS_DISABLE DICS_STATE = 0x00000002 // The device is being disabled.
|
||||||
|
DICS_PROPCHANGE DICS_STATE = 0x00000003 // The properties of the device have changed.
|
||||||
|
DICS_START DICS_STATE = 0x00000004 // The device is being started (if the request is for the currently active hardware profile).
|
||||||
|
DICS_STOP DICS_STATE = 0x00000005 // The device is being stopped. The driver stack will be unloaded and the CSCONFIGFLAG_DO_NOT_START flag will be set for the device.
|
||||||
|
)
|
||||||
|
|
||||||
// DICS_FLAG specifies the scope of a device property change
|
// DICS_FLAG specifies the scope of a device property change
|
||||||
type DICS_FLAG uint32
|
type DICS_FLAG uint32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
DICS_FLAG_GLOBAL DICS_FLAG = 0x00000001 // make change in all hardware profiles
|
DICS_FLAG_GLOBAL DICS_FLAG = 0x00000001 // make change in all hardware profiles
|
||||||
DICS_FLAG_CONFIGSPECIFIC DICS_FLAG = 0x00000002 // make change in specified profile only
|
DICS_FLAG_CONFIGSPECIFIC DICS_FLAG = 0x00000002 // make change in specified profile only
|
||||||
DICS_FLAG_CONFIGGENERAL DICS_FLAG = 0x00000004 // 1 or more hardware profile-specific changes to follow
|
DICS_FLAG_CONFIGGENERAL DICS_FLAG = 0x00000004 // 1 or more hardware profile-specific changes to follow (obsolete)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// PropChangeParams is a structure corresponding to a DIF_PROPERTYCHANGE install function.
|
||||||
|
type PropChangeParams struct {
|
||||||
|
ClassInstallHeader ClassInstallHeader
|
||||||
|
StateChange DICS_STATE
|
||||||
|
Scope DICS_FLAG
|
||||||
|
HwProfile uint32
|
||||||
|
}
|
||||||
|
|
||||||
// DI_REMOVEDEVICE specifies the scope of the device removal
|
// DI_REMOVEDEVICE specifies the scope of the device removal
|
||||||
type DI_REMOVEDEVICE uint32
|
type DI_REMOVEDEVICE uint32
|
||||||
|
|
||||||
@@ -304,12 +322,12 @@ type DrvInfoData struct {
|
|||||||
DriverVersion uint64
|
DriverVersion uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (data *DrvInfoData) GetDescription() string {
|
func (data *DrvInfoData) Description() string {
|
||||||
return windows.UTF16ToString(data.description[:])
|
return windows.UTF16ToString(data.description[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (data *DrvInfoData) SetDescription(description string) error {
|
func (data *DrvInfoData) SetDescription(description string) error {
|
||||||
str, err := syscall.UTF16FromString(description)
|
str, err := windows.UTF16FromString(description)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -317,12 +335,12 @@ func (data *DrvInfoData) SetDescription(description string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (data *DrvInfoData) GetMfgName() string {
|
func (data *DrvInfoData) MfgName() string {
|
||||||
return windows.UTF16ToString(data.mfgName[:])
|
return windows.UTF16ToString(data.mfgName[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (data *DrvInfoData) SetMfgName(mfgName string) error {
|
func (data *DrvInfoData) SetMfgName(mfgName string) error {
|
||||||
str, err := syscall.UTF16FromString(mfgName)
|
str, err := windows.UTF16FromString(mfgName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -330,12 +348,12 @@ func (data *DrvInfoData) SetMfgName(mfgName string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (data *DrvInfoData) GetProviderName() string {
|
func (data *DrvInfoData) ProviderName() string {
|
||||||
return windows.UTF16ToString(data.providerName[:])
|
return windows.UTF16ToString(data.providerName[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (data *DrvInfoData) SetProviderName(providerName string) error {
|
func (data *DrvInfoData) SetProviderName(providerName string) error {
|
||||||
str, err := syscall.UTF16FromString(providerName)
|
str, err := windows.UTF16FromString(providerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -382,19 +400,19 @@ type DrvInfoDetailData struct {
|
|||||||
hardwareID [ANYSIZE_ARRAY]uint16
|
hardwareID [ANYSIZE_ARRAY]uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
func (data *DrvInfoDetailData) GetSectionName() string {
|
func (data *DrvInfoDetailData) SectionName() string {
|
||||||
return windows.UTF16ToString(data.sectionName[:])
|
return windows.UTF16ToString(data.sectionName[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (data *DrvInfoDetailData) GetInfFileName() string {
|
func (data *DrvInfoDetailData) InfFileName() string {
|
||||||
return windows.UTF16ToString(data.infFileName[:])
|
return windows.UTF16ToString(data.infFileName[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (data *DrvInfoDetailData) GetDrvDescription() string {
|
func (data *DrvInfoDetailData) DrvDescription() string {
|
||||||
return windows.UTF16ToString(data.drvDescription[:])
|
return windows.UTF16ToString(data.drvDescription[:])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (data *DrvInfoDetailData) GetHardwareID() string {
|
func (data *DrvInfoDetailData) HardwareID() string {
|
||||||
if data.compatIDsOffset > 1 {
|
if data.compatIDsOffset > 1 {
|
||||||
bufW := data.getBuf()
|
bufW := data.getBuf()
|
||||||
return windows.UTF16ToString(bufW[:wcslen(bufW)])
|
return windows.UTF16ToString(bufW[:wcslen(bufW)])
|
||||||
@@ -403,7 +421,7 @@ func (data *DrvInfoDetailData) GetHardwareID() string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
func (data *DrvInfoDetailData) GetCompatIDs() []string {
|
func (data *DrvInfoDetailData) CompatIDs() []string {
|
||||||
a := make([]string, 0)
|
a := make([]string, 0)
|
||||||
|
|
||||||
if data.compatIDsLength > 0 {
|
if data.compatIDsLength > 0 {
|
||||||
@@ -434,10 +452,10 @@ func (data *DrvInfoDetailData) getBuf() []uint16 {
|
|||||||
// IsCompatible method tests if given hardware ID matches the driver or is listed on the compatible ID list.
|
// IsCompatible method tests if given hardware ID matches the driver or is listed on the compatible ID list.
|
||||||
func (data *DrvInfoDetailData) IsCompatible(hwid string) bool {
|
func (data *DrvInfoDetailData) IsCompatible(hwid string) bool {
|
||||||
hwidLC := strings.ToLower(hwid)
|
hwidLC := strings.ToLower(hwid)
|
||||||
if strings.ToLower(data.GetHardwareID()) == hwidLC {
|
if strings.ToLower(data.HardwareID()) == hwidLC {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
a := data.GetCompatIDs()
|
a := data.CompatIDs()
|
||||||
for i := range a {
|
for i := range a {
|
||||||
if strings.ToLower(a[i]) == hwidLC {
|
if strings.ToLower(a[i]) == hwidLC {
|
||||||
return true
|
return true
|
||||||
@@ -538,3 +556,13 @@ const (
|
|||||||
|
|
||||||
SPDRP_MAXIMUM_PROPERTY SPDRP = 0x00000025 // Upper bound on ordinals
|
SPDRP_MAXIMUM_PROPERTY SPDRP = 0x00000025 // Upper bound on ordinals
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CR_SUCCESS = 0x0
|
||||||
|
CR_BUFFER_SMALL = 0x1a
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CM_GET_DEVICE_INTERFACE_LIST_PRESENT = 0 // only currently 'live' device interfaces
|
||||||
|
CM_GET_DEVICE_INTERFACE_LIST_ALL_DEVICES = 1 // all registered device interfaces, live or not
|
||||||
|
)
|
||||||
|
|||||||
@@ -38,32 +38,36 @@ func errnoErr(e syscall.Errno) error {
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
modsetupapi = windows.NewLazySystemDLL("setupapi.dll")
|
modsetupapi = windows.NewLazySystemDLL("setupapi.dll")
|
||||||
|
modCfgMgr32 = windows.NewLazySystemDLL("CfgMgr32.dll")
|
||||||
|
|
||||||
procSetupDiCreateDeviceInfoListExW = modsetupapi.NewProc("SetupDiCreateDeviceInfoListExW")
|
procSetupDiCreateDeviceInfoListExW = modsetupapi.NewProc("SetupDiCreateDeviceInfoListExW")
|
||||||
procSetupDiGetDeviceInfoListDetailW = modsetupapi.NewProc("SetupDiGetDeviceInfoListDetailW")
|
procSetupDiGetDeviceInfoListDetailW = modsetupapi.NewProc("SetupDiGetDeviceInfoListDetailW")
|
||||||
procSetupDiCreateDeviceInfoW = modsetupapi.NewProc("SetupDiCreateDeviceInfoW")
|
procSetupDiCreateDeviceInfoW = modsetupapi.NewProc("SetupDiCreateDeviceInfoW")
|
||||||
procSetupDiEnumDeviceInfo = modsetupapi.NewProc("SetupDiEnumDeviceInfo")
|
procSetupDiEnumDeviceInfo = modsetupapi.NewProc("SetupDiEnumDeviceInfo")
|
||||||
procSetupDiDestroyDeviceInfoList = modsetupapi.NewProc("SetupDiDestroyDeviceInfoList")
|
procSetupDiDestroyDeviceInfoList = modsetupapi.NewProc("SetupDiDestroyDeviceInfoList")
|
||||||
procSetupDiBuildDriverInfoList = modsetupapi.NewProc("SetupDiBuildDriverInfoList")
|
procSetupDiBuildDriverInfoList = modsetupapi.NewProc("SetupDiBuildDriverInfoList")
|
||||||
procSetupDiCancelDriverInfoSearch = modsetupapi.NewProc("SetupDiCancelDriverInfoSearch")
|
procSetupDiCancelDriverInfoSearch = modsetupapi.NewProc("SetupDiCancelDriverInfoSearch")
|
||||||
procSetupDiEnumDriverInfoW = modsetupapi.NewProc("SetupDiEnumDriverInfoW")
|
procSetupDiEnumDriverInfoW = modsetupapi.NewProc("SetupDiEnumDriverInfoW")
|
||||||
procSetupDiGetSelectedDriverW = modsetupapi.NewProc("SetupDiGetSelectedDriverW")
|
procSetupDiGetSelectedDriverW = modsetupapi.NewProc("SetupDiGetSelectedDriverW")
|
||||||
procSetupDiSetSelectedDriverW = modsetupapi.NewProc("SetupDiSetSelectedDriverW")
|
procSetupDiSetSelectedDriverW = modsetupapi.NewProc("SetupDiSetSelectedDriverW")
|
||||||
procSetupDiGetDriverInfoDetailW = modsetupapi.NewProc("SetupDiGetDriverInfoDetailW")
|
procSetupDiGetDriverInfoDetailW = modsetupapi.NewProc("SetupDiGetDriverInfoDetailW")
|
||||||
procSetupDiDestroyDriverInfoList = modsetupapi.NewProc("SetupDiDestroyDriverInfoList")
|
procSetupDiDestroyDriverInfoList = modsetupapi.NewProc("SetupDiDestroyDriverInfoList")
|
||||||
procSetupDiGetClassDevsExW = modsetupapi.NewProc("SetupDiGetClassDevsExW")
|
procSetupDiGetClassDevsExW = modsetupapi.NewProc("SetupDiGetClassDevsExW")
|
||||||
procSetupDiCallClassInstaller = modsetupapi.NewProc("SetupDiCallClassInstaller")
|
procSetupDiCallClassInstaller = modsetupapi.NewProc("SetupDiCallClassInstaller")
|
||||||
procSetupDiOpenDevRegKey = modsetupapi.NewProc("SetupDiOpenDevRegKey")
|
procSetupDiOpenDevRegKey = modsetupapi.NewProc("SetupDiOpenDevRegKey")
|
||||||
procSetupDiGetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiGetDeviceRegistryPropertyW")
|
procSetupDiGetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiGetDeviceRegistryPropertyW")
|
||||||
procSetupDiSetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiSetDeviceRegistryPropertyW")
|
procSetupDiSetDeviceRegistryPropertyW = modsetupapi.NewProc("SetupDiSetDeviceRegistryPropertyW")
|
||||||
procSetupDiGetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiGetDeviceInstallParamsW")
|
procSetupDiGetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiGetDeviceInstallParamsW")
|
||||||
procSetupDiGetClassInstallParamsW = modsetupapi.NewProc("SetupDiGetClassInstallParamsW")
|
procSetupDiGetDeviceInstanceIdW = modsetupapi.NewProc("SetupDiGetDeviceInstanceIdW")
|
||||||
procSetupDiSetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiSetDeviceInstallParamsW")
|
procSetupDiGetClassInstallParamsW = modsetupapi.NewProc("SetupDiGetClassInstallParamsW")
|
||||||
procSetupDiSetClassInstallParamsW = modsetupapi.NewProc("SetupDiSetClassInstallParamsW")
|
procSetupDiSetDeviceInstallParamsW = modsetupapi.NewProc("SetupDiSetDeviceInstallParamsW")
|
||||||
procSetupDiClassNameFromGuidExW = modsetupapi.NewProc("SetupDiClassNameFromGuidExW")
|
procSetupDiSetClassInstallParamsW = modsetupapi.NewProc("SetupDiSetClassInstallParamsW")
|
||||||
procSetupDiClassGuidsFromNameExW = modsetupapi.NewProc("SetupDiClassGuidsFromNameExW")
|
procSetupDiClassNameFromGuidExW = modsetupapi.NewProc("SetupDiClassNameFromGuidExW")
|
||||||
procSetupDiGetSelectedDevice = modsetupapi.NewProc("SetupDiGetSelectedDevice")
|
procSetupDiClassGuidsFromNameExW = modsetupapi.NewProc("SetupDiClassGuidsFromNameExW")
|
||||||
procSetupDiSetSelectedDevice = modsetupapi.NewProc("SetupDiSetSelectedDevice")
|
procSetupDiGetSelectedDevice = modsetupapi.NewProc("SetupDiGetSelectedDevice")
|
||||||
|
procSetupDiSetSelectedDevice = modsetupapi.NewProc("SetupDiSetSelectedDevice")
|
||||||
|
procCM_Get_Device_Interface_List_SizeW = modCfgMgr32.NewProc("CM_Get_Device_Interface_List_SizeW")
|
||||||
|
procCM_Get_Device_Interface_ListW = modCfgMgr32.NewProc("CM_Get_Device_Interface_ListW")
|
||||||
)
|
)
|
||||||
|
|
||||||
func setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) {
|
func setupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr, machineName *uint16, reserved uintptr) (handle DevInfo, err error) {
|
||||||
@@ -285,6 +289,18 @@ func setupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInf
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setupDiGetDeviceInstanceId(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, instanceId *uint16, instanceIdSize uint32, instanceIdRequiredSize *uint32) (err error) {
|
||||||
|
r1, _, e1 := syscall.Syscall6(procSetupDiGetDeviceInstanceIdW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(instanceId)), uintptr(instanceIdSize), uintptr(unsafe.Pointer(instanceIdRequiredSize)), 0)
|
||||||
|
if r1 == 0 {
|
||||||
|
if e1 != 0 {
|
||||||
|
err = errnoErr(e1)
|
||||||
|
} else {
|
||||||
|
err = syscall.EINVAL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
func SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) {
|
func SetupDiGetClassInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, classInstallParams *ClassInstallHeader, classInstallParamsSize uint32, requiredSize *uint32) (err error) {
|
||||||
r1, _, e1 := syscall.Syscall6(procSetupDiGetClassInstallParamsW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(classInstallParams)), uintptr(classInstallParamsSize), uintptr(unsafe.Pointer(requiredSize)), 0)
|
r1, _, e1 := syscall.Syscall6(procSetupDiGetClassInstallParamsW.Addr(), 5, uintptr(deviceInfoSet), uintptr(unsafe.Pointer(deviceInfoData)), uintptr(unsafe.Pointer(classInstallParams)), uintptr(classInstallParamsSize), uintptr(unsafe.Pointer(requiredSize)), 0)
|
||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
@@ -368,3 +384,15 @@ func SetupDiSetSelectedDevice(deviceInfoSet DevInfo, deviceInfoData *DevInfoData
|
|||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func cm_Get_Device_Interface_List_Size(len *uint32, interfaceClass *windows.GUID, deviceID *uint16, flags uint32) (ret uint32) {
|
||||||
|
r0, _, _ := syscall.Syscall6(procCM_Get_Device_Interface_List_SizeW.Addr(), 4, uintptr(unsafe.Pointer(len)), uintptr(unsafe.Pointer(interfaceClass)), uintptr(unsafe.Pointer(deviceID)), uintptr(flags), 0, 0)
|
||||||
|
ret = uint32(r0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func cm_Get_Device_Interface_List(interfaceClass *windows.GUID, deviceID *uint16, buffer *uint16, bufferLen uint32, flags uint32) (ret uint32) {
|
||||||
|
r0, _, _ := syscall.Syscall6(procCM_Get_Device_Interface_ListW.Addr(), 5, uintptr(unsafe.Pointer(interfaceClass)), uintptr(unsafe.Pointer(deviceID)), uintptr(unsafe.Pointer(buffer)), uintptr(bufferLen), uintptr(flags), 0)
|
||||||
|
ret = uint32(r0)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user