35 Commits

Author SHA1 Message Date
Jason A. Donenfeld
583ebe99f1 version: bump snapshot 2019-05-17 10:28:04 +02:00
Jason A. Donenfeld
a6dd282600 makefile: do not show warning on non-linux 2019-05-17 10:27:51 +02:00
Simon Rozman
7d5f5bcc0d wintun: change acronyms to uppercase
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-17 10:22:34 +02:00
Jason A. Donenfeld
3bf41b06ae global: regroup all imports 2019-05-14 09:09:52 +02:00
Jason A. Donenfeld
3147f00089 wintun: registry: fix nits 2019-05-11 17:25:48 +02:00
Simon Rozman
6c1b66802f wintun: registry: revise value reading
- Make getStringValueRetry() reusable for reading any value type. This
  merges code from GetIntegerValueWait().
- expandString() >> toString() and extend to support REG_MULTI_SZ
  (to return first value of REG_MULTI_SZ). Furthermore, doing our own
  UTF-16 to UTF-8 conversion works around a bug in windows/registry's
  GetStringValue() non-zero terminated string handling.
- Provide toInteger() analogous to toString()
- GetStringValueWait() tolerates and reads REG_MULTI_SZ too now. It
  returns REG_MULTI_SZ[0], making GetFirstStringValueWait() redundant.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-11 17:14:37 +02:00
Jason A. Donenfeld
5669ed326f wintun: call HrRenameConnection in another thread 2019-05-10 21:31:37 +02:00
Jason A. Donenfeld
2d847a38a2 wintun: add LUID accessor 2019-05-10 21:30:23 +02:00
Jason A. Donenfeld
7a8553aef0 wintun: enumerate faster by using COMPATDRIVER instead of CLASSDRIVER 2019-05-10 20:30:59 +02:00
Jason A. Donenfeld
a6045ac042 wintun: destroy devinfolist after usage 2019-05-10 20:19:11 +02:00
Simon Rozman
1c92b48415 wintun: registry: replace REG_NOTIFY with NOTIFY
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-10 18:09:20 +02:00
Jason A. Donenfeld
c267965bf8 wintun: IpConfig is a MULTI_SZ, and fix errors 2019-05-10 18:06:49 +02:00
Jason A. Donenfeld
1bf1dadf15 wintun: poll for device key
It's actually pretty hard to guess where it is.
2019-05-10 17:34:03 +02:00
Jason A. Donenfeld
f9dcfccbb7 wintun: fix scope of error object 2019-05-10 16:59:24 +02:00
Simon Rozman
7e962a9932 wintun: wait for interface registry key on device creation
By using RegNotifyChangeKeyValue(). Also disable dead gateway detection.

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-10 16:43:58 +02:00
Jason A. Donenfeld
586112b5d7 conn: remove scope when sanity checking IP address format 2019-05-09 15:42:35 +02:00
Simon Rozman
dcb8f1aa6b wintun: fix GUID leading zero padding
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-05-09 12:16:21 +02:00
Jason A. Donenfeld
b16b0e4cf7 mod: update deps 2019-05-03 09:37:29 +02:00
Jason A. Donenfeld
81ca08f1b3 setupapi: safer aliasing of slice types 2019-05-03 09:34:00 +02:00
Jason A. Donenfeld
2e988467c2 wintun: work around GetInterface staleness bug 2019-05-03 00:42:36 +02:00
Jason A. Donenfeld
46dbf54040 wintun: don't retry when not creating
The only time we're trying to counteract the race condition is when
we're creating a driver. When we're simply looking up all drivers, it
doesn't make sense to retry.
2019-05-02 23:53:15 +02:00
Jason A. Donenfeld
247e14693a wintun: try harder to open registry key
This sucks. Can we please find a deterministic way of doing this
instead?
2019-04-29 14:00:49 +02:00
Jason A. Donenfeld
3945a299ff go.mod: use vendored winio 2019-04-29 08:09:38 +02:00
Jason A. Donenfeld
bb42ec7d18 tun: freebsd: work around numerous kernel panics on shutdown
There are numerous race conditions. But even this will crash it:

while true; do ifconfig tun0 create; ifconfig tun0 destroy; done

It seems like LLv6 is related, which we're not using anyway, so
explicitly disable it on the interface.
2019-04-23 18:00:23 +09:00
Simon Rozman
f1dc167901 setupapi: Fix struct size mismatches
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-04-19 10:08:11 +02:00
Jason A. Donenfeld
c7a26dfef3 setupapi: actually fix padding by rounding up to sizeof(void*) 2019-04-19 10:19:00 +09:00
Jason A. Donenfeld
d024393335 tun: darwin: write routeSocket variable in helper
Otherwise the race detector "complains".
2019-04-19 07:53:19 +09:00
Jason A. Donenfeld
d9078fe772 main: revise warnings 2019-04-19 07:48:09 +09:00
Jason A. Donenfeld
d3dd991e4e device: send: check packet length before freeing element 2019-04-18 23:23:03 +09:00
Simon Rozman
5811447b38 setupapi: Revise DrvInfoDetailData struct size calculation
Go adds trailing padding to DrvInfoDetailData struct in GOARCH=386 which
confuses SetupAPI expecting exactly sizeof(SP_DRVINFO_DETAIL_DATA).

Signed-off-by: Simon Rozman <simon@rozman.si>
2019-04-18 10:39:22 +02:00
Jason A. Donenfeld
e0a8c22aa6 windows: use proper constants from updated x/sys 2019-04-13 02:02:02 +02:00
Jason A. Donenfeld
0b77bf78cd conn: linux: RTA_MARK has moved to x/sys 2019-04-13 02:01:20 +02:00
Simon Rozman
ef5f3ad80a tun: windows: Adopt new error codes returned by Wintun
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-04-11 19:38:11 +02:00
Simon Rozman
a291fdd746 tun: windows: do not sleep after OPERATION_ABORTED on write
Signed-off-by: Simon Rozman <simon@rozman.si>
2019-04-11 19:37:04 +02:00
Jason A. Donenfeld
d50e390904 main_windows: use proper version constant 2019-04-09 10:45:40 +02:00
46 changed files with 909 additions and 299 deletions

View File

@@ -8,7 +8,7 @@ all: generate-version-and-build
ifeq ($(shell go env GOOS)|$(wildcard .git),linux|) ifeq ($(shell go env GOOS)|$(wildcard .git),linux|)
$(error Do not build this for Linux. Instead use the Linux kernel module. See wireguard.com/install/ for more info.) $(error Do not build this for Linux. Instead use the Linux kernel module. See wireguard.com/install/ for more info.)
else else ifeq ($(shell go env GOOS),linux)
ireallywantobuildon_linux.go: ireallywantobuildon_linux.go:
@printf "WARNING: This software is meant for use on non-Linux\nsystems. For Linux, please use the kernel module\ninstead. See wireguard.com/install/ for more info.\n\n" >&2 @printf "WARNING: This software is meant for use on non-Linux\nsystems. For Linux, please use the kernel module\ninstead. See wireguard.com/install/ for more info.\n\n" >&2
@printf 'package main\nconst UseTheKernelModuleInstead = 0xdeadbabe\n' > "$@" @printf 'package main\nconst UseTheKernelModuleInstead = 0xdeadbabe\n' > "$@"

View File

@@ -7,8 +7,9 @@ package device
import ( import (
"encoding/binary" "encoding/binary"
"golang.org/x/sys/windows"
"unsafe" "unsafe"
"golang.org/x/sys/windows"
) )
const ( const (

View File

@@ -7,9 +7,11 @@ package device
import ( import (
"errors" "errors"
"net"
"strings"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
"net"
) )
const ( const (
@@ -41,13 +43,18 @@ type Endpoint interface {
} }
func parseEndpoint(s string) (*net.UDPAddr, error) { func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address // ensure that the host is an IP address
host, _, err := net.SplitHostPort(s) host, _, err := net.SplitHostPort(s)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
// trying to make sure with a small sanity test that this is a real IP address and
// not something that's likely to incur DNS lookups.
host = host[:i]
}
if ip := net.ParseIP(host); ip == nil { if ip := net.ParseIP(host); ip == nil {
return nil, errors.New("Failed to parse IP address: " + host) return nil, errors.New("Failed to parse IP address: " + host)
} }

View File

@@ -18,13 +18,14 @@ package device
import ( import (
"errors" "errors"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
"net" "net"
"strconv" "strconv"
"sync" "sync"
"syscall" "syscall"
"unsafe" "unsafe"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
) )
const ( const (
@@ -719,7 +720,7 @@ func (bind *nativeBind) routineRouteListener(device *Device) {
peer.endpoint.(*NativeEndpoint).src4().src, peer.endpoint.(*NativeEndpoint).src4().src,
unix.RtAttr{ unix.RtAttr{
Len: 8, Len: 8,
Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix Type: unix.RTA_MARK,
}, },
uint32(bind.lastMark), uint32(bind.lastMark),
} }

View File

@@ -8,10 +8,11 @@ package device
import ( import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"sync" "sync"
"time" "time"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
) )
type CookieChecker struct { type CookieChecker struct {

View File

@@ -6,12 +6,13 @@
package device package device
import ( import (
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/tun"
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/tun"
) )
const ( const (

View File

@@ -7,8 +7,9 @@ package device
import ( import (
"encoding/hex" "encoding/hex"
"golang.org/x/crypto/blake2s"
"testing" "testing"
"golang.org/x/crypto/blake2s"
) )
type KDFTest struct { type KDFTest struct {

View File

@@ -7,9 +7,10 @@ package device
import ( import (
"crypto/cipher" "crypto/cipher"
"golang.zx2c4.com/wireguard/replay"
"sync" "sync"
"time" "time"
"golang.zx2c4.com/wireguard/replay"
) )
/* Due to limitations in Go and /x/crypto there is currently /* Due to limitations in Go and /x/crypto there is currently

View File

@@ -8,8 +8,9 @@
package device package device
import ( import (
"golang.org/x/sys/unix"
"runtime" "runtime"
"golang.org/x/sys/unix"
) )
var fwmarkIoctl int var fwmarkIoctl int

View File

@@ -9,9 +9,10 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/subtle" "crypto/subtle"
"hash"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/curve25519" "golang.org/x/crypto/curve25519"
"hash"
) )
/* KDF related functions. /* KDF related functions.

View File

@@ -7,12 +7,13 @@ package device
import ( import (
"errors" "errors"
"sync"
"time"
"golang.org/x/crypto/blake2s" "golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305" "golang.org/x/crypto/poly1305"
"golang.zx2c4.com/wireguard/tai64n" "golang.zx2c4.com/wireguard/tai64n"
"sync"
"time"
) )
const ( const (

View File

@@ -9,6 +9,7 @@ import (
"crypto/subtle" "crypto/subtle"
"encoding/hex" "encoding/hex"
"errors" "errors"
"golang.org/x/crypto/chacha20poly1305" "golang.org/x/crypto/chacha20poly1305"
) )

View File

@@ -8,14 +8,15 @@ package device
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net" "net"
"strconv" "strconv"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
) )
type QueueHandshakeElement struct { type QueueHandshakeElement struct {

View File

@@ -8,13 +8,14 @@ package device
import ( import (
"bytes" "bytes"
"encoding/binary" "encoding/binary"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net" "net"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
) )
/* Outbound flow /* Outbound flow
@@ -601,6 +602,9 @@ func (peer *Peer) RoutineSequentialSender() {
length := uint64(len(elem.packet)) length := uint64(len(elem.packet))
err := peer.SendBuffer(elem.packet) err := peer.SendBuffer(elem.packet)
if len(elem.packet) != MessageKeepaliveSize {
peer.timersDataSent()
}
device.PutMessageBuffer(elem.buffer) device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem) device.PutOutboundElement(elem)
if err != nil { if err != nil {
@@ -609,9 +613,6 @@ func (peer *Peer) RoutineSequentialSender() {
} }
atomic.AddUint64(&peer.stats.txBytes, length) atomic.AddUint64(&peer.stats.txBytes, length)
if len(elem.packet) != MessageKeepaliveSize {
peer.timersDataSent()
}
peer.keepKeyFreshSending() peer.keepKeyFreshSending()
} }
} }

View File

@@ -6,8 +6,9 @@
package device package device
import ( import (
"golang.zx2c4.com/wireguard/tun"
"sync/atomic" "sync/atomic"
"golang.zx2c4.com/wireguard/tun"
) )
const DefaultMTU = 1420 const DefaultMTU = 1420

View File

@@ -8,13 +8,14 @@ package device
import ( import (
"bufio" "bufio"
"fmt" "fmt"
"golang.zx2c4.com/wireguard/ipc"
"io" "io"
"net" "net"
"strconv" "strconv"
"strings" "strings"
"sync/atomic" "sync/atomic"
"time" "time"
"golang.zx2c4.com/wireguard/ipc"
) )
type IPCError struct { type IPCError struct {

View File

@@ -1,3 +1,3 @@
package device package device
const WireGuardGoVersion = "0.0.20190409" const WireGuardGoVersion = "0.0.20190517"

11
go.mod
View File

@@ -1,8 +1,13 @@
module golang.zx2c4.com/wireguard module golang.zx2c4.com/wireguard
go 1.12
require ( require (
github.com/Microsoft/go-winio v0.4.12 github.com/Microsoft/go-winio v0.4.12
golang.org/x/crypto v0.0.0-20190320223903-b7391e95e576 github.com/pkg/errors v0.8.1 // indirect
golang.org/x/net v0.0.0-20190320064053-1272bf9dcd53 golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734
golang.org/x/sys v0.0.0-20190321052220-f7bb7a8bee54 golang.org/x/net v0.0.0-20190502183928-7f726cade0ab
golang.org/x/sys v0.0.0-20190502175342-a43fa875dd82
) )
replace github.com/Microsoft/go-winio => golang.zx2c4.com/wireguard/windows v0.0.0-20190429060359-b01600290cd4

20
go.sum
View File

@@ -1,11 +1,15 @@
github.com/Microsoft/go-winio v0.4.12 h1:xAfWHN1IrQ0NJ9TBC0KBZoqLjzDTr1ML+4MywiUOryc= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/Microsoft/go-winio v0.4.12/go.mod h1:VhR8bwka0BXejwEJY73c50VrPtXAaKcyvVC4A4RozmA= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190320223903-b7391e95e576 h1:aUX/1G2gFSs4AsJJg2cL3HuoRhCSCz733FE5GUSuaT4= golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734 h1:p/H982KKEjUnLJkM3tt/LemDnOc1GiZL5FCVlORJ5zo=
golang.org/x/crypto v0.0.0-20190320223903-b7391e95e576/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190426145343-a29dc8fdc734/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/net v0.0.0-20190320064053-1272bf9dcd53 h1:kcXqo9vE6fsZY5X5Rd7R1l7fTgnWaDCVmln65REefiE= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190320064053-1272bf9dcd53/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190502183928-7f726cade0ab h1:9RfW3ktsOZxgo9YNbBAjq1FWzc/igwEcUzZz8IXgSbk=
golang.org/x/net v0.0.0-20190502183928-7f726cade0ab/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190321052220-f7bb7a8bee54 h1:xe1/2UUJRmA9iDglQSlkx8c5n3twv58+K0mPpC2zmhA= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190321052220-f7bb7a8bee54/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190502175342-a43fa875dd82 h1:vsphBvatvfbhlb4PO1BYSr9dzugGxJ/SQHoNufZJq1w=
golang.org/x/sys v0.0.0-20190502175342-a43fa875dd82/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.zx2c4.com/wireguard/windows v0.0.0-20190429060359-b01600290cd4 h1:wueYNew2pMLl/LcKqX4PAzc+zV4suK9+DJaZ8yIEHkM=
golang.zx2c4.com/wireguard/windows v0.0.0-20190429060359-b01600290cd4/go.mod h1:Y+FYqVFaQO6a+1uigm0N0GiuaZrLEaBxEiJ8tfH9sMQ=

View File

@@ -10,11 +10,12 @@ package ipc
import ( import (
"errors" "errors"
"fmt" "fmt"
"golang.org/x/sys/unix"
"net" "net"
"os" "os"
"path" "path"
"unsafe" "unsafe"
"golang.org/x/sys/unix"
) )
var socketDirectory = "/var/run/wireguard" var socketDirectory = "/var/run/wireguard"

View File

@@ -8,11 +8,12 @@ package ipc
import ( import (
"errors" "errors"
"fmt" "fmt"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
"net" "net"
"os" "os"
"path" "path"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
) )
var socketDirectory = "/var/run/wireguard" var socketDirectory = "/var/run/wireguard"

View File

@@ -6,8 +6,9 @@
package ipc package ipc
import ( import (
"github.com/Microsoft/go-winio"
"net" "net"
"github.com/Microsoft/go-winio"
) )
//TODO: replace these with actual standard windows error numbers from the win package //TODO: replace these with actual standard windows error numbers from the win package

37
main.go
View File

@@ -9,14 +9,15 @@ package main
import ( import (
"fmt" "fmt"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
"os" "os"
"os/signal" "os/signal"
"runtime" "runtime"
"strconv" "strconv"
"syscall" "syscall"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun"
) )
const ( const (
@@ -36,38 +37,28 @@ func printUsage() {
} }
func warning() { func warning() {
if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" { if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
return return
} }
shouldQuit := os.Getenv("WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD") != "1"
shouldQuit := false
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING") fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W This is alpha software. It will very likely not G") fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G")
fmt.Fprintln(os.Stderr, "W do what it is supposed to do, and things may go G") fmt.Fprintln(os.Stderr, "W which is probably unnecessary and foolish. This G")
fmt.Fprintln(os.Stderr, "W horribly wrong. You have been warned. Proceed G") fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G")
fmt.Fprintln(os.Stderr, "W at your own risk. G") fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G")
if runtime.GOOS == "linux" { fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G")
shouldQuit = os.Getenv("WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD") != "1" fmt.Fprintln(os.Stderr, "W implementation. For more information on G")
fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G")
fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W Furthermore, you are running this software on a G")
fmt.Fprintln(os.Stderr, "W Linux kernel, which is probably unnecessary and G")
fmt.Fprintln(os.Stderr, "W foolish. This is because the Linux kernel has G")
fmt.Fprintln(os.Stderr, "W built-in first class support for WireGuard, and G")
fmt.Fprintln(os.Stderr, "W this support is much more refined than this G")
fmt.Fprintln(os.Stderr, "W program. For more information on installing the G")
fmt.Fprintln(os.Stderr, "W kernel module, please visit: G")
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G") fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
if shouldQuit { if shouldQuit {
fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "W If you still want to use this program, against G") fmt.Fprintln(os.Stderr, "W If you still want to use this program, against G")
fmt.Fprintln(os.Stderr, "W the sage advice here, please first export this G") fmt.Fprintln(os.Stderr, "W the advice here, please first export this G")
fmt.Fprintln(os.Stderr, "W environment variable: G") fmt.Fprintln(os.Stderr, "W environment variable: G")
fmt.Fprintln(os.Stderr, "W WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD=1 G") fmt.Fprintln(os.Stderr, "W WG_I_PREFER_BUGGY_USERSPACE_TO_POLISHED_KMOD=1 G")
} }
}
fmt.Fprintln(os.Stderr, "W G") fmt.Fprintln(os.Stderr, "W G")
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING") fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")

View File

@@ -7,12 +7,13 @@ package main
import ( import (
"fmt" "fmt"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"os" "os"
"os/signal" "os/signal"
"syscall" "syscall"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/ipc"
"golang.zx2c4.com/wireguard/tun" "golang.zx2c4.com/wireguard/tun"
) )
@@ -27,11 +28,13 @@ func main() {
} }
interfaceName := os.Args[1] interfaceName := os.Args[1]
fmt.Fprintln(os.Stderr, "Warning: this is a test program for Windows, mainly used for debugging this Go package. For a real WireGuard for Windows client, the repo you want is <https://git.zx2c4.com/wireguard-windows/>, which includes this code as a module.")
logger := device.NewLogger( logger := device.NewLogger(
device.LogLevelDebug, device.LogLevelDebug,
fmt.Sprintf("(%s) ", interfaceName), fmt.Sprintf("(%s) ", interfaceName),
) )
logger.Info.Println("Starting wireguard-go version", WireGuardGoVersion) logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
logger.Debug.Println("Debug log enabled") logger.Debug.Println("Debug log enabled")
tun, err := tun.CreateTUN(interfaceName) tun, err := tun.CreateTUN(interfaceName)

View File

@@ -7,9 +7,10 @@ package rwcancel
import ( import (
"errors" "errors"
"golang.org/x/sys/unix"
"os" "os"
"syscall" "syscall"
"golang.org/x/sys/unix"
) )
func max(a, b int) int { func max(a, b int) int {

View File

@@ -8,9 +8,10 @@ package tun
import ( import (
"bytes" "bytes"
"errors" "errors"
"golang.zx2c4.com/wireguard/tun"
"os" "os"
"testing" "testing"
"golang.zx2c4.com/wireguard/tun"
) )
/* Helpers for writing unit tests /* Helpers for writing unit tests

View File

@@ -7,13 +7,14 @@ package tun
import ( import (
"fmt" "fmt"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"io/ioutil" "io/ioutil"
"net" "net"
"os" "os"
"syscall" "syscall"
"unsafe" "unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
) )
const utunControlName = "com.apple.net.utun_control" const utunControlName = "com.apple.net.utun_control"
@@ -47,7 +48,10 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
statusMTU int statusMTU int
) )
defer close(tun.events) defer func() {
close(tun.events)
tun.routeSocket = -1
}()
data := make([]byte, os.Getpagesize()) data := make([]byte, os.Getpagesize())
for { for {
@@ -292,7 +296,6 @@ func (tun *NativeTun) Close() error {
if tun.routeSocket != -1 { if tun.routeSocket != -1 {
unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR) unix.Shutdown(tun.routeSocket, unix.SHUT_RDWR)
err2 = unix.Close(tun.routeSocket) err2 = unix.Close(tun.routeSocket)
tun.routeSocket = -1
} else if tun.events != nil { } else if tun.events != nil {
close(tun.events) close(tun.events)
} }

View File

@@ -9,19 +9,30 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"net" "net"
"os" "os"
"syscall" "syscall"
"unsafe" "unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
) )
// _TUNSIFHEAD, value derived from sys/net/{if_tun,ioccom}.h // _TUNSIFHEAD, value derived from sys/net/{if_tun,ioccom}.h
// const _TUNSIFHEAD = ((0x80000000) | (((4) & ((1 << 13) - 1) ) << 16) | (uint32(byte('t')) << 8) | (96)) // const _TUNSIFHEAD = ((0x80000000) | (((4) & ((1 << 13) - 1) ) << 16) | (uint32(byte('t')) << 8) | (96))
const _TUNSIFHEAD = 0x80047460 const (
const _TUNSIFMODE = 0x8004745e _TUNSIFHEAD = 0x80047460
const _TUNSIFPID = 0x2000745f _TUNSIFMODE = 0x8004745e
_TUNSIFPID = 0x2000745f
)
//TODO: move into x/sys/unix
const (
SIOCGIFINFO_IN6 = 0xc048696c
SIOCSIFINFO_IN6 = 0xc048696d
ND6_IFF_AUTO_LINKLOCAL = 0x20
ND6_IFF_NO_DAD = 0x100
)
// Iface status string max len // Iface status string max len
const _IFSTATMAX = 800 const _IFSTATMAX = 800
@@ -32,7 +43,7 @@ const SIZEOF_UINTPTR = 4 << (^uintptr(0) >> 32 & 1)
type ifreq_ptr struct { type ifreq_ptr struct {
Name [unix.IFNAMSIZ]byte Name [unix.IFNAMSIZ]byte
Data uintptr Data uintptr
Pad0 [24 - SIZEOF_UINTPTR]byte Pad0 [16 - SIZEOF_UINTPTR]byte
} }
// Structure for iface mtu get/set ioctls // Structure for iface mtu get/set ioctls
@@ -48,6 +59,23 @@ type ifstat struct {
Ascii [_IFSTATMAX]byte Ascii [_IFSTATMAX]byte
} }
// Structures for nd6 flag manipulation
type in6_ndireq struct {
Name [unix.IFNAMSIZ]byte
Linkmtu uint32
Maxmtu uint32
Basereachable uint32
Reachable uint32
Retrans uint32
Flags uint32
Recalctm int
Chlim uint8
Initialized uint8
Randomseed0 [8]byte
Randomseed1 [8]byte
Randomid [8]byte
}
type NativeTun struct { type NativeTun struct {
name string name string
tunFile *os.File tunFile *os.File
@@ -191,23 +219,18 @@ func tunName(fd uintptr) (string, error) {
// Destroy a named system interface // Destroy a named system interface
func tunDestroy(name string) error { func tunDestroy(name string) error {
// open control socket // Open control socket.
var fd int var fd int
fd, err := unix.Socket( fd, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM,
0, 0,
) )
if err != nil { if err != nil {
return err return err
} }
defer unix.Close(fd) defer unix.Close(fd)
// do ioctl call
var ifr [32]byte var ifr [32]byte
copy(ifr[:], name) copy(ifr[:], name)
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(
@@ -216,7 +239,6 @@ func tunDestroy(name string) error {
uintptr(unix.SIOCIFDESTROY), uintptr(unix.SIOCIFDESTROY),
uintptr(unsafe.Pointer(&ifr[0])), uintptr(unsafe.Pointer(&ifr[0])),
) )
if errno != 0 { if errno != 0 {
return fmt.Errorf("failed to destroy interface %s: %s", name, errno.Error()) return fmt.Errorf("failed to destroy interface %s: %s", name, errno.Error())
} }
@@ -263,33 +285,71 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
}) })
if errno != 0 { if errno != 0 {
return nil, fmt.Errorf("error %s", errno.Error()) tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to put into IFHEAD mode: %v", errno)
} }
// Rename tun interface // Open control sockets
// Open control socket
confd, err := unix.Socket( confd, err := unix.Socket(
unix.AF_INET, unix.AF_INET,
unix.SOCK_DGRAM, unix.SOCK_DGRAM,
0, 0,
) )
if err != nil { if err != nil {
tunFile.Close()
tunDestroy(assignedName)
return nil, err return nil, err
} }
defer unix.Close(confd) defer unix.Close(confd)
confd6, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
0,
)
if err != nil {
tunFile.Close()
tunDestroy(assignedName)
return nil, err
}
defer unix.Close(confd6)
// set up struct for iface rename // Disable link-local v6, not just because WireGuard doesn't do that anyway, but
// also because there are serious races with attaching and detaching LLv6 addresses
// in relation to interface lifetime within the FreeBSD kernel.
var ndireq in6_ndireq
copy(ndireq.Name[:], assignedName)
_, _, errno = unix.Syscall(
unix.SYS_IOCTL,
uintptr(confd6),
uintptr(SIOCGIFINFO_IN6),
uintptr(unsafe.Pointer(&ndireq)),
)
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to get nd6 flags for %s: %v", assignedName, errno)
}
ndireq.Flags = ndireq.Flags &^ ND6_IFF_AUTO_LINKLOCAL
ndireq.Flags = ndireq.Flags | ND6_IFF_NO_DAD
_, _, errno = unix.Syscall(
unix.SYS_IOCTL,
uintptr(confd6),
uintptr(SIOCSIFINFO_IN6),
uintptr(unsafe.Pointer(&ndireq)),
)
if errno != 0 {
tunFile.Close()
tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to set nd6 flags for %s: %v", assignedName, errno)
}
// Rename the interface
var newnp [unix.IFNAMSIZ]byte var newnp [unix.IFNAMSIZ]byte
copy(newnp[:], name) copy(newnp[:], name)
var ifr ifreq_ptr var ifr ifreq_ptr
copy(ifr.Name[:], assignedName) copy(ifr.Name[:], assignedName)
ifr.Data = uintptr(unsafe.Pointer(&newnp[0])) ifr.Data = uintptr(unsafe.Pointer(&newnp[0]))
//do actual ioctl to rename iface
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(
unix.SYS_IOCTL, unix.SYS_IOCTL,
uintptr(confd), uintptr(confd),
@@ -298,8 +358,8 @@ func CreateTUN(name string, mtu int) (TUNDevice, error) {
) )
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(name) tunDestroy(assignedName)
return nil, fmt.Errorf("failed to rename %s to %s: %s", assignedName, name, errno.Error()) return nil, fmt.Errorf("Failed to rename %s to %s: %v", assignedName, name, errno)
} }
return CreateTUNFromFile(tunFile, mtu) return CreateTUNFromFile(tunFile, mtu)

View File

@@ -12,15 +12,16 @@ import (
"bytes" "bytes"
"errors" "errors"
"fmt" "fmt"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
"net" "net"
"os" "os"
"sync" "sync"
"syscall" "syscall"
"time" "time"
"unsafe" "unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
) )
const ( const (

View File

@@ -7,13 +7,14 @@ package tun
import ( import (
"fmt" "fmt"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"io/ioutil" "io/ioutil"
"net" "net"
"os" "os"
"syscall" "syscall"
"unsafe" "unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
) )
// Structure for iface mtu get/set ioctls // Structure for iface mtu get/set ioctls

View File

@@ -7,9 +7,9 @@ package tun
import ( import (
"errors" "errors"
"fmt"
"os" "os"
"sync" "sync"
"syscall"
"time" "time"
"unsafe" "unsafe"
@@ -48,7 +48,7 @@ type NativeTun struct {
wrBuff *exchgBufWrite wrBuff *exchgBufWrite
events chan TUNEvent events chan TUNEvent
errors chan error errors chan error
forcedMtu int forcedMTU int
} }
func packetAlign(size uint32) uint32 { func packetAlign(size uint32) uint32 {
@@ -60,32 +60,35 @@ func packetAlign(size uint32) uint32 {
// adapter with the same name exist, it is reused. // adapter with the same name exist, it is reused.
// //
func CreateTUN(ifname string) (TUNDevice, error) { func CreateTUN(ifname string) (TUNDevice, error) {
var err error
var wt *wintun.Wintun
// Does an interface with this name already exist? // Does an interface with this name already exist?
wt, err := wintun.GetInterface(ifname, 0) wt, err = wintun.GetInterface(ifname, 0)
if wt == nil { if wt == nil {
// Interface does not exist or an error occured. Create one. // Interface does not exist or an error occurred. Create one.
wt, _, err = wintun.CreateInterface("WireGuard Tunnel Adapter", 0) wt, _, err = wintun.CreateInterface("WireGuard Tunnel Adapter", 0)
if err != nil { if err != nil {
return nil, errors.New("Creating Wintun adapter failed: " + err.Error()) return nil, fmt.Errorf("wintun.CreateInterface: %v", err)
} }
} else if err != nil { } else if err != nil {
// Foreign interface with the same name found. // Foreign interface with the same name found.
// We could create a Wintun interface under a temporary name. But, should our // We could create a Wintun interface under a temporary name. But, should our
// proces die without deleting this interface first, the interface would remain // process die without deleting this interface first, the interface would remain
// orphaned. // orphaned.
return nil, err return nil, fmt.Errorf("wintun.GetInterface: %v", err)
} }
err = wt.SetInterfaceName(ifname) err = wt.SetInterfaceName(ifname)
if err != nil { if err != nil {
wt.DeleteInterface(0) wt.DeleteInterface(0)
return nil, err return nil, fmt.Errorf("wintun.SetInterfaceName: %v", err)
} }
err = wt.FlushInterface() err = wt.FlushInterface()
if err != nil { if err != nil {
wt.DeleteInterface(0) wt.DeleteInterface(0)
return nil, errors.New("Flushing interface failed: " + err.Error()) return nil, fmt.Errorf("wintun.FlushInterface: %v", err)
} }
return &NativeTun{ return &NativeTun{
@@ -94,7 +97,7 @@ func CreateTUN(ifname string) (TUNDevice, error) {
wrBuff: &exchgBufWrite{}, wrBuff: &exchgBufWrite{},
events: make(chan TUNEvent, 10), events: make(chan TUNEvent, 10),
errors: make(chan error, 1), errors: make(chan error, 1),
forcedMtu: 1500, forcedMTU: 1500,
}, nil }, nil
} }
@@ -218,12 +221,12 @@ func (tun *NativeTun) Close() error {
} }
func (tun *NativeTun) MTU() (int, error) { func (tun *NativeTun) MTU() (int, error) {
return tun.forcedMtu, nil return tun.forcedMTU, nil
} }
//TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes. //TODO: This is a temporary hack. We really need to be monitoring the interface in real time and adapting to MTU changes.
func (tun *NativeTun) ForceMtu(mtu int) { func (tun *NativeTun) ForceMTU(mtu int) {
tun.forcedMtu = mtu tun.forcedMTU = mtu
} }
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) { func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
@@ -273,7 +276,7 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
retries-- retries--
continue continue
} }
if ok && pe.Err == syscall.Errno(6) /*windows.ERROR_INVALID_HANDLE*/ { if ok && pe.Err == windows.ERROR_HANDLE_EOF {
tun.closeTUN() tun.closeTUN()
break break
} }
@@ -301,7 +304,7 @@ func (tun *NativeTun) Flush() error {
} }
// Flush write buffer. // Flush write buffer.
retries := retryTimeout * retryRate retries := 1000
for { for {
_, err = file.Write(tun.wrBuff.data[:tun.wrBuff.offset]) _, err = file.Write(tun.wrBuff.data[:tun.wrBuff.offset])
tun.wrBuff.packetNum = 0 tun.wrBuff.packetNum = 0
@@ -312,11 +315,10 @@ func (tun *NativeTun) Flush() error {
return os.ErrClosed return os.ErrClosed
} }
if retries > 0 && ok && pe.Err == windows.ERROR_OPERATION_ABORTED { if retries > 0 && ok && pe.Err == windows.ERROR_OPERATION_ABORTED {
time.Sleep(time.Second / retryRate)
retries-- retries--
continue continue
} }
if ok && pe.Err == syscall.Errno(6) /*windows.ERROR_INVALID_HANDLE*/ { if ok && pe.Err == windows.ERROR_HANDLE_EOF {
tun.closeTUN() tun.closeTUN()
break break
} }
@@ -371,3 +373,10 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
func (tun *NativeTun) GUID() windows.GUID { func (tun *NativeTun) GUID() windows.GUID {
return tun.wt.CfgInstanceID return tun.wt.CfgInstanceID
} }
//
// GUID returns Windows adapter instance ID.
//
func (tun *NativeTun) LUID() uint64 {
return ((uint64(tun.wt.LUIDIndex) & ((1 << 24) - 1)) << 24) | ((uint64(tun.wt.IfType) & ((1 << 16) - 1)) << 48)
}

View File

@@ -12,7 +12,7 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
) )
//sys clsidFromString(lpsz *uint16, pclsid *windows.GUID) (hr int32) = ole32.CLSIDFromString //sys clsidFromString(lpsz *uint16, pclsid *windows.GUID) (err error) [failretval!=0] = ole32.CLSIDFromString
// //
// FromString parses "{XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}" string to GUID. // FromString parses "{XXXXXX-XXXX-XXXX-XXXX-XXXXXXXXXXXX}" string to GUID.
@@ -22,14 +22,11 @@ func FromString(str string) (*windows.GUID, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
guid := &windows.GUID{} guid := &windows.GUID{}
err = clsidFromString(strUTF16, guid)
hr := clsidFromString(strUTF16, guid) if err != nil {
if hr < 0 { return nil, err
return nil, syscall.Errno(hr)
} }
return guid, nil return guid, nil
} }
@@ -40,5 +37,5 @@ func FromString(str string) (*windows.GUID, error) {
// The resulting string is uppercase. // The resulting string is uppercase.
// //
func ToString(guid *windows.GUID) string { func ToString(guid *windows.GUID) string {
return fmt.Sprintf("{%06X-%04X-%04X-%04X-%012X}", guid.Data1, guid.Data2, guid.Data3, guid.Data4[:2], guid.Data4[2:]) return fmt.Sprintf("{%08X-%04X-%04X-%04X-%012X}", guid.Data1, guid.Data2, guid.Data3, guid.Data4[:2], guid.Data4[2:])
} }

View File

@@ -42,8 +42,14 @@ var (
procCLSIDFromString = modole32.NewProc("CLSIDFromString") procCLSIDFromString = modole32.NewProc("CLSIDFromString")
) )
func clsidFromString(lpsz *uint16, pclsid *windows.GUID) (hr int32) { func clsidFromString(lpsz *uint16, pclsid *windows.GUID) (err error) {
r0, _, _ := syscall.Syscall(procCLSIDFromString.Addr(), 2, uintptr(unsafe.Pointer(lpsz)), uintptr(unsafe.Pointer(pclsid)), 0) r1, _, e1 := syscall.Syscall(procCLSIDFromString.Addr(), 2, uintptr(unsafe.Pointer(lpsz)), uintptr(unsafe.Pointer(pclsid)), 0)
hr = int32(r0) if r1 != 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return return
} }

View File

@@ -0,0 +1,8 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package registry
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zregistry_windows.go registry_windows.go

View File

@@ -0,0 +1,272 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package registry
import (
"errors"
"fmt"
"runtime"
"strings"
"time"
"unsafe"
"golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry"
)
const (
// REG_NOTIFY_CHANGE_NAME notifies the caller if a subkey is added or deleted.
REG_NOTIFY_CHANGE_NAME uint32 = 0x00000001
// REG_NOTIFY_CHANGE_ATTRIBUTES notifies the caller of changes to the attributes of the key, such as the security descriptor information.
REG_NOTIFY_CHANGE_ATTRIBUTES uint32 = 0x00000002
// REG_NOTIFY_CHANGE_LAST_SET notifies the caller of changes to a value of the key. This can include adding or deleting a value, or changing an existing value.
REG_NOTIFY_CHANGE_LAST_SET uint32 = 0x00000004
// REG_NOTIFY_CHANGE_SECURITY notifies the caller of changes to the security descriptor of the key.
REG_NOTIFY_CHANGE_SECURITY uint32 = 0x00000008
// REG_NOTIFY_THREAD_AGNOSTIC indicates that the lifetime of the registration must not be tied to the lifetime of the thread issuing the RegNotifyChangeKeyValue call. Note: This flag value is only supported in Windows 8 and later.
REG_NOTIFY_THREAD_AGNOSTIC uint32 = 0x10000000
)
//sys regNotifyChangeKeyValue(key windows.Handle, watchSubtree bool, notifyFilter uint32, event windows.Handle, asynchronous bool) (regerrno error) = advapi32.RegNotifyChangeKeyValue
func OpenKeyWait(k registry.Key, path string, access uint32, timeout time.Duration) (registry.Key, error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
deadline := time.Now().Add(timeout)
pathSpl := strings.Split(path, "\\")
for i := 0; ; i++ {
keyName := pathSpl[i]
isLast := i+1 == len(pathSpl)
event, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return 0, fmt.Errorf("Error creating event: %v", err)
}
defer windows.CloseHandle(event)
var key registry.Key
for {
err = regNotifyChangeKeyValue(windows.Handle(k), false, REG_NOTIFY_CHANGE_NAME, windows.Handle(event), true)
if err != nil {
return 0, fmt.Errorf("Setting up change notification on registry key failed: %v", err)
}
var accessFlags uint32
if isLast {
accessFlags = access
} else {
accessFlags = registry.NOTIFY
}
key, err = registry.OpenKey(k, keyName, accessFlags)
if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
timeout := time.Until(deadline) / time.Millisecond
if timeout < 0 {
timeout = 0
}
s, err := windows.WaitForSingleObject(event, uint32(timeout))
if err != nil {
return 0, fmt.Errorf("Unable to wait on registry key: %v", err)
}
if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
return 0, errors.New("Timeout waiting for registry key")
}
} else if err != nil {
return 0, fmt.Errorf("Error opening registry key %v: %v", path, err)
} else {
if isLast {
return key, nil
}
defer key.Close()
break
}
}
k = key
}
}
func WaitForKey(k registry.Key, path string, timeout time.Duration) error {
key, err := OpenKeyWait(k, path, registry.NOTIFY, timeout)
if err != nil {
return err
}
key.Close()
return nil
}
//
// getValue is more or less the same as windows/registry's getValue.
//
func getValue(k registry.Key, name string, buf []byte) (value []byte, valueType uint32, err error) {
var name16 *uint16
name16, err = windows.UTF16PtrFromString(name)
if err != nil {
return
}
n := uint32(len(buf))
for {
err = windows.RegQueryValueEx(windows.Handle(k), name16, nil, &valueType, (*byte)(unsafe.Pointer(&buf[0])), &n)
if err == nil {
value = buf[:n]
return
}
if err != windows.ERROR_MORE_DATA {
return
}
if n <= uint32(len(buf)) {
return
}
buf = make([]byte, n)
}
}
//
// getValueRetry function reads any value from registry. It waits for
// the registry value to become available or returns error on timeout.
//
// Key must be opened with at least QUERY_VALUE|NOTIFY access.
//
func getValueRetry(key registry.Key, name string, buf []byte, timeout time.Duration) ([]byte, uint32, error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
event, err := windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return nil, 0, fmt.Errorf("Error creating event: %v", err)
}
defer windows.CloseHandle(event)
deadline := time.Now().Add(timeout)
for {
err := regNotifyChangeKeyValue(windows.Handle(key), false, REG_NOTIFY_CHANGE_LAST_SET, windows.Handle(event), true)
if err != nil {
return nil, 0, fmt.Errorf("Setting up change notification on registry value failed: %v", err)
}
buf, valueType, err := getValue(key, name, buf)
if err == windows.ERROR_FILE_NOT_FOUND || err == windows.ERROR_PATH_NOT_FOUND {
timeout := time.Until(deadline) / time.Millisecond
if timeout < 0 {
timeout = 0
}
s, err := windows.WaitForSingleObject(event, uint32(timeout))
if err != nil {
return nil, 0, fmt.Errorf("Unable to wait on registry value: %v", err)
}
if s == uint32(windows.WAIT_TIMEOUT) { // windows.WAIT_TIMEOUT status const is misclassified as error in golang.org/x/sys/windows
return nil, 0, errors.New("Timeout waiting for registry value")
}
} else if err != nil {
return nil, 0, fmt.Errorf("Error reading registry value %v: %v", name, err)
} else {
return buf, valueType, nil
}
}
}
func toString(buf []byte, valueType uint32, err error) (string, error) {
if err != nil {
return "", err
}
var value string
switch valueType {
case registry.SZ, registry.EXPAND_SZ, registry.MULTI_SZ:
if len(buf) == 0 {
return "", nil
}
value = windows.UTF16ToString((*[(1 << 30) - 1]uint16)(unsafe.Pointer(&buf[0]))[:len(buf)/2])
default:
return "", registry.ErrUnexpectedType
}
if valueType != registry.EXPAND_SZ {
// Value does not require expansion.
return value, nil
}
valueExp, err := registry.ExpandString(value)
if err != nil {
// Expanding failed: return original sting value.
return value, nil
}
// Return expanded value.
return valueExp, nil
}
func toInteger(buf []byte, valueType uint32, err error) (uint64, error) {
if err != nil {
return 0, err
}
switch valueType {
case registry.DWORD:
if len(buf) != 4 {
return 0, errors.New("DWORD value is not 4 bytes long")
}
var val uint32
copy((*[4]byte)(unsafe.Pointer(&val))[:], buf)
return uint64(val), nil
case registry.QWORD:
if len(buf) != 8 {
return 0, errors.New("QWORD value is not 8 bytes long")
}
var val uint64
copy((*[8]byte)(unsafe.Pointer(&val))[:], buf)
return val, nil
default:
return 0, registry.ErrUnexpectedType
}
}
//
// GetStringValueWait function reads a string value from registry. It waits
// for the registry value to become available or returns error on timeout.
//
// Key must be opened with at least QUERY_VALUE|NOTIFY access.
//
// If the value type is REG_EXPAND_SZ the environment variables are expanded.
// Should expanding fail, original string value and nil error are returned.
//
// If the value type is REG_MULTI_SZ only the first string is returned.
//
func GetStringValueWait(key registry.Key, name string, timeout time.Duration) (string, error) {
return toString(getValueRetry(key, name, make([]byte, 256), timeout))
}
//
// GetStringValue function reads a string value from registry.
//
// Key must be opened with at least QUERY_VALUE access.
//
// If the value type is REG_EXPAND_SZ the environment variables are expanded.
// Should expanding fail, original string value and nil error are returned.
//
// If the value type is REG_MULTI_SZ only the first string is returned.
//
func GetStringValue(key registry.Key, name string) (string, error) {
return toString(getValue(key, name, make([]byte, 256)))
}
//
// GetIntegerValueWait function reads a DWORD32 or QWORD value from registry.
// It waits for the registry value to become available or returns error on
// timeout.
//
// Key must be opened with at least QUERY_VALUE|NOTIFY access.
//
func GetIntegerValueWait(key registry.Key, name string, timeout time.Duration) (uint64, error) {
return toInteger(getValueRetry(key, name, make([]byte, 8), timeout))
}

View File

@@ -0,0 +1,103 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package registry
import (
"testing"
"time"
"golang.org/x/sys/windows/registry"
)
const keyRoot = registry.CURRENT_USER
const pathRoot = "Software\\WireGuardRegistryTest"
const path = pathRoot + "\\foobar"
const pathFake = pathRoot + "\\raboof"
func Test_WaitForKey(t *testing.T) {
registry.DeleteKey(keyRoot, path)
registry.DeleteKey(keyRoot, pathRoot)
go func() {
time.Sleep(time.Second * 1)
key, _, err := registry.CreateKey(keyRoot, pathFake, registry.QUERY_VALUE)
if err != nil {
t.Errorf("Error creating registry key: %v", err)
}
key.Close()
registry.DeleteKey(keyRoot, pathFake)
key, _, err = registry.CreateKey(keyRoot, path, registry.QUERY_VALUE)
if err != nil {
t.Errorf("Error creating registry key: %v", err)
}
key.Close()
}()
err := WaitForKey(keyRoot, path, time.Second*2)
if err != nil {
t.Errorf("Error waiting for registry key: %v", err)
}
registry.DeleteKey(keyRoot, path)
registry.DeleteKey(keyRoot, pathRoot)
err = WaitForKey(keyRoot, path, time.Second*1)
if err == nil {
t.Error("Registry key notification expected to timeout but it succeeded.")
}
}
func Test_GetValueWait(t *testing.T) {
registry.DeleteKey(keyRoot, path)
registry.DeleteKey(keyRoot, pathRoot)
go func() {
time.Sleep(time.Second * 1)
key, _, err := registry.CreateKey(keyRoot, path, registry.SET_VALUE)
if err != nil {
t.Errorf("Error creating registry key: %v", err)
}
time.Sleep(time.Second * 1)
key.SetStringValue("name1", "eulav")
key.SetExpandStringValue("name2", "value")
time.Sleep(time.Second * 1)
key.SetDWordValue("name3", ^uint32(123))
key.SetDWordValue("name4", 123)
key.Close()
}()
key, err := OpenKeyWait(keyRoot, path, registry.QUERY_VALUE|registry.NOTIFY, time.Second*2)
if err != nil {
t.Errorf("Error waiting for registry key: %v", err)
}
valueStr, err := GetStringValueWait(key, "name2", time.Second*2)
if err != nil {
t.Errorf("Error waiting for registry value: %v", err)
}
if valueStr != "value" {
t.Errorf("Wrong value read: %v", valueStr)
}
_, err = GetStringValueWait(key, "nonexisting", time.Second*1)
if err == nil {
t.Error("Registry value notification expected to timeout but it succeeded.")
}
valueInt, err := GetIntegerValueWait(key, "name4", time.Second*2)
if err != nil {
t.Errorf("Error waiting for registry value: %v", err)
}
if valueInt != 123 {
t.Errorf("Wrong value read: %v", valueInt)
}
_, err = GetIntegerValueWait(key, "nonexisting", time.Second*1)
if err == nil {
t.Error("Registry value notification expected to timeout but it succeeded.")
}
key.Close()
registry.DeleteKey(keyRoot, path)
registry.DeleteKey(keyRoot, pathRoot)
}

View File

@@ -0,0 +1,63 @@
// Code generated by 'go generate'; DO NOT EDIT.
package registry
import (
"syscall"
"unsafe"
"golang.org/x/sys/windows"
)
var _ unsafe.Pointer
// Do the interface allocations only once for common
// Errno values.
const (
errnoERROR_IO_PENDING = 997
)
var (
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
)
// errnoErr returns common boxed Errno values, to prevent
// allocations at runtime.
func errnoErr(e syscall.Errno) error {
switch e {
case 0:
return nil
case errnoERROR_IO_PENDING:
return errERROR_IO_PENDING
}
// TODO: add more here, after collecting data on the common
// error values see on Windows. (perhaps when running
// all.bat?)
return e
}
var (
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
procRegNotifyChangeKeyValue = modadvapi32.NewProc("RegNotifyChangeKeyValue")
)
func regNotifyChangeKeyValue(key windows.Handle, watchSubtree bool, notifyFilter uint32, event windows.Handle, asynchronous bool) (regerrno error) {
var _p0 uint32
if watchSubtree {
_p0 = 1
} else {
_p0 = 0
}
var _p1 uint32
if asynchronous {
_p1 = 1
} else {
_p1 = 0
}
r0, _, _ := syscall.Syscall6(procRegNotifyChangeKeyValue.Addr(), 5, uintptr(key), uintptr(_p0), uintptr(notifyFilter), uintptr(event), uintptr(_p1), 0)
if r0 != 0 {
regerrno = syscall.Errno(r0)
}
return
}

View File

@@ -1,42 +0,0 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package wintun
import (
"golang.org/x/sys/windows/registry"
"time"
)
const (
numRetries = 25
retryTimeout = 100 * time.Millisecond
)
func registryOpenKeyRetry(k registry.Key, path string, access uint32) (key registry.Key, err error) {
for i := 0; i < numRetries; i++ {
key, err = registry.OpenKey(k, path, access)
if err == nil {
break
}
if i != numRetries - 1 {
time.Sleep(retryTimeout)
}
}
return
}
func keyGetStringValueRetry(k registry.Key, name string) (val string, valtype uint32, err error) {
for i := 0; i < numRetries; i++ {
val, valtype, err = k.GetStringValue(name)
if err == nil {
break
}
if i != numRetries - 1 {
time.Sleep(retryTimeout)
}
}
return
}

View File

@@ -8,6 +8,7 @@ package setupapi
import ( import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"runtime"
"syscall" "syscall"
"unsafe" "unsafe"
@@ -34,7 +35,7 @@ func SetupDiCreateDeviceInfoListEx(classGUID *windows.GUID, hwndParent uintptr,
// SetupDiGetDeviceInfoListDetail function retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name. // SetupDiGetDeviceInfoListDetail function retrieves information associated with a device information set including the class GUID, remote computer handle, and remote computer name.
func SetupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo) (deviceInfoSetDetailData *DevInfoListDetailData, err error) { func SetupDiGetDeviceInfoListDetail(deviceInfoSet DevInfo) (deviceInfoSetDetailData *DevInfoListDetailData, err error) {
data := &DevInfoListDetailData{} data := &DevInfoListDetailData{}
data.size = uint32(unsafe.Sizeof(*data)) data.size = sizeofDevInfoListDetailData
return data, setupDiGetDeviceInfoListDetail(deviceInfoSet, data) return data, setupDiGetDeviceInfoListDetail(deviceInfoSet, data)
} }
@@ -155,7 +156,7 @@ func SetupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoDa
var bufLen uint32 var bufLen uint32
data := (*DrvInfoDetailData)(unsafe.Pointer(&buf[0])) data := (*DrvInfoDetailData)(unsafe.Pointer(&buf[0]))
data.size = uint32(unsafe.Sizeof(*data)) data.size = sizeofDrvInfoDetailData
err := setupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData, data, bufCapacity, &bufLen) err := setupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData, data, bufCapacity, &bufLen)
if err == nil { if err == nil {
@@ -168,7 +169,7 @@ func SetupDiGetDriverInfoDetail(deviceInfoSet DevInfo, deviceInfoData *DevInfoDa
// The buffer was too small. Now that we got the required size, create another one big enough and retry. // The buffer was too small. Now that we got the required size, create another one big enough and retry.
buf := make([]byte, bufLen) buf := make([]byte, bufLen)
data := (*DrvInfoDetailData)(unsafe.Pointer(&buf[0])) data := (*DrvInfoDetailData)(unsafe.Pointer(&buf[0]))
data.size = uint32(unsafe.Sizeof(*data)) data.size = sizeofDrvInfoDetailData
err = setupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData, data, bufLen, &bufLen) err = setupDiGetDriverInfoDetail(deviceInfoSet, deviceInfoData, driverInfoData, data, bufLen, &bufLen)
if err == nil { if err == nil {
@@ -261,9 +262,13 @@ func SetupDiGetDeviceRegistryProperty(deviceInfoSet DevInfo, deviceInfoData *Dev
func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) { func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) {
switch dataType { switch dataType {
case windows.REG_SZ: case windows.REG_SZ:
return windows.UTF16ToString(BufToUTF16(buf)), nil ret := windows.UTF16ToString(bufToUTF16(buf))
runtime.KeepAlive(buf)
return ret, nil
case windows.REG_EXPAND_SZ: case windows.REG_EXPAND_SZ:
return registry.ExpandString(windows.UTF16ToString(BufToUTF16(buf))) ret, err := registry.ExpandString(windows.UTF16ToString(bufToUTF16(buf)))
runtime.KeepAlive(buf)
return ret, err
case windows.REG_BINARY: case windows.REG_BINARY:
return buf, nil return buf, nil
case windows.REG_DWORD_LITTLE_ENDIAN: case windows.REG_DWORD_LITTLE_ENDIAN:
@@ -271,7 +276,7 @@ func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) {
case windows.REG_DWORD_BIG_ENDIAN: case windows.REG_DWORD_BIG_ENDIAN:
return binary.BigEndian.Uint32(buf), nil return binary.BigEndian.Uint32(buf), nil
case windows.REG_MULTI_SZ: case windows.REG_MULTI_SZ:
bufW := BufToUTF16(buf) bufW := bufToUTF16(buf)
a := []string{} a := []string{}
for i := 0; i < len(bufW); { for i := 0; i < len(bufW); {
j := i + wcslen(bufW[i:]) j := i + wcslen(bufW[i:])
@@ -280,6 +285,7 @@ func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) {
} }
i = j + 1 i = j + 1
} }
runtime.KeepAlive(buf)
return a, nil return a, nil
case windows.REG_QWORD_LITTLE_ENDIAN: case windows.REG_QWORD_LITTLE_ENDIAN:
return binary.LittleEndian.Uint64(buf), nil return binary.LittleEndian.Uint64(buf), nil
@@ -288,8 +294,8 @@ func getRegistryValue(buf []byte, dataType uint32) (interface{}, error) {
} }
} }
// BufToUTF16 function reinterprets []byte buffer as []uint16 // bufToUTF16 function reinterprets []byte buffer as []uint16
func BufToUTF16(buf []byte) []uint16 { func bufToUTF16(buf []byte) []uint16 {
sl := struct { sl := struct {
addr *uint16 addr *uint16
len int len int
@@ -298,8 +304,8 @@ func BufToUTF16(buf []byte) []uint16 {
return *(*[]uint16)(unsafe.Pointer(&sl)) return *(*[]uint16)(unsafe.Pointer(&sl))
} }
// UTF16ToBuf function reinterprets []uint16 as []byte // utf16ToBuf function reinterprets []uint16 as []byte
func UTF16ToBuf(buf []uint16) []byte { func utf16ToBuf(buf []uint16) []byte {
sl := struct { sl := struct {
addr *byte addr *byte
len int len int
@@ -334,6 +340,16 @@ func (deviceInfoSet DevInfo) SetDeviceRegistryProperty(deviceInfoData *DevInfoDa
return SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, propertyBuffers) return SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, propertyBuffers)
} }
func (deviceInfoSet DevInfo) SetDeviceRegistryPropertyString(deviceInfoData *DevInfoData, property SPDRP, str string) error {
str16, err := windows.UTF16FromString(str)
if err != nil {
return err
}
err = SetupDiSetDeviceRegistryProperty(deviceInfoSet, deviceInfoData, property, utf16ToBuf(append(str16, 0)))
runtime.KeepAlive(str16)
return err
}
//sys setupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) = setupapi.SetupDiGetDeviceInstallParamsW //sys setupDiGetDeviceInstallParams(deviceInfoSet DevInfo, deviceInfoData *DevInfoData, deviceInstallParams *DevInstallParams) (err error) = setupapi.SetupDiGetDeviceInstallParamsW
// SetupDiGetDeviceInstallParams function retrieves device installation parameters for a device information set or a particular device information element. // SetupDiGetDeviceInstallParams function retrieves device installation parameters for a device information set or a particular device information element.

View File

@@ -6,6 +6,7 @@
package setupapi package setupapi
import ( import (
"runtime"
"strings" "strings"
"syscall" "syscall"
"testing" "testing"
@@ -131,7 +132,7 @@ func TestSetupDiEnumDeviceInfo(t *testing.T) {
for i := 0; true; i++ { for i := 0; true; i++ {
data, err := devInfoList.EnumDeviceInfo(i) data, err := devInfoList.EnumDeviceInfo(i)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
@@ -153,7 +154,7 @@ func TestDevInfo_BuildDriverInfoList(t *testing.T) {
for i := 0; true; i++ { for i := 0; true; i++ {
deviceData, err := devInfoList.EnumDeviceInfo(i) deviceData, err := devInfoList.EnumDeviceInfo(i)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
@@ -170,7 +171,7 @@ func TestDevInfo_BuildDriverInfoList(t *testing.T) {
for j := 0; true; j++ { for j := 0; true; j++ {
driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, j) driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, j)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
@@ -264,7 +265,7 @@ func TestSetupDiGetClassDevsEx(t *testing.T) {
devInfoList.Close() devInfoList.Close()
t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail") t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail")
} else { } else {
if errWin, ok := err.(syscall.Errno); !ok || errWin != 87 /*ERROR_INVALID_PARAMETER*/ { if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER {
t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail with ERROR_INVALID_PARAMETER") t.Errorf("SetupDiGetClassDevsEx(nil, ...) should fail with ERROR_INVALID_PARAMETER")
} }
} }
@@ -280,7 +281,7 @@ func TestSetupDiOpenDevRegKey(t *testing.T) {
for i := 0; true; i++ { for i := 0; true; i++ {
data, err := devInfoList.EnumDeviceInfo(i) data, err := devInfoList.EnumDeviceInfo(i)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
@@ -304,7 +305,7 @@ func TestSetupDiGetDeviceRegistryProperty(t *testing.T) {
for i := 0; true; i++ { for i := 0; true; i++ {
data, err := devInfoList.EnumDeviceInfo(i) data, err := devInfoList.EnumDeviceInfo(i)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
@@ -334,7 +335,7 @@ func TestSetupDiGetDeviceRegistryProperty(t *testing.T) {
val, err = devInfoList.GetDeviceRegistryProperty(data, SPDRP_COMPATIBLEIDS) val, err = devInfoList.GetDeviceRegistryProperty(data, SPDRP_COMPATIBLEIDS)
if err != nil { if err != nil {
// Some devices have no SPDRP_COMPATIBLEIDS. // Some devices have no SPDRP_COMPATIBLEIDS.
if errWin, ok := err.(syscall.Errno); !ok || errWin != 13 /*windows.ERROR_INVALID_DATA*/ { if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_INVALID_DATA {
t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_COMPATIBLEIDS): %s", err.Error()) t.Errorf("Error calling SetupDiGetDeviceRegistryProperty(SPDRP_COMPATIBLEIDS): %s", err.Error())
} }
} }
@@ -361,7 +362,7 @@ func TestSetupDiGetDeviceInstallParams(t *testing.T) {
for i := 0; true; i++ { for i := 0; true; i++ {
data, err := devInfoList.EnumDeviceInfo(i) data, err := devInfoList.EnumDeviceInfo(i)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
@@ -399,7 +400,7 @@ func TestSetupDiClassNameFromGuidEx(t *testing.T) {
if err == nil { if err == nil {
t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail") t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail")
} else { } else {
if errWin, ok := err.(syscall.Errno); !ok || errWin != 1784 /*ERROR_INVALID_USER_BUFFER*/ { if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_INVALID_USER_BUFFER {
t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail with ERROR_INVALID_USER_BUFFER") t.Errorf("SetupDiClassNameFromGuidEx(nil) should fail with ERROR_INVALID_USER_BUFFER")
} }
} }
@@ -440,7 +441,7 @@ func TestSetupDiGetSelectedDevice(t *testing.T) {
for i := 0; true; i++ { for i := 0; true; i++ {
data, err := devInfoList.EnumDeviceInfo(i) data, err := devInfoList.EnumDeviceInfo(i)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
@@ -463,7 +464,7 @@ func TestSetupDiGetSelectedDevice(t *testing.T) {
if err == nil { if err == nil {
t.Errorf("SetupDiSetSelectedDevice(nil) should fail") t.Errorf("SetupDiSetSelectedDevice(nil) should fail")
} else { } else {
if errWin, ok := err.(syscall.Errno); !ok || errWin != 87 /*ERROR_INVALID_PARAMETER*/ { if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_INVALID_PARAMETER {
t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER") t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER")
} }
} }
@@ -471,7 +472,7 @@ func TestSetupDiGetSelectedDevice(t *testing.T) {
func TestUTF16ToBuf(t *testing.T) { func TestUTF16ToBuf(t *testing.T) {
buf := []uint16{0x0123, 0x4567, 0x89ab, 0xcdef} buf := []uint16{0x0123, 0x4567, 0x89ab, 0xcdef}
buf2 := UTF16ToBuf(buf) buf2 := utf16ToBuf(buf)
if len(buf)*2 != len(buf2) || if len(buf)*2 != len(buf2) ||
cap(buf)*2 != cap(buf2) || cap(buf)*2 != cap(buf2) ||
buf2[0] != 0x23 || buf2[1] != 0x01 || buf2[0] != 0x23 || buf2[1] != 0x01 ||
@@ -480,4 +481,5 @@ func TestUTF16ToBuf(t *testing.T) {
buf2[6] != 0xef || buf2[7] != 0xcd { buf2[6] != 0xef || buf2[7] != 0xcd {
t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER") t.Errorf("SetupDiSetSelectedDevice(nil) should fail with ERROR_INVALID_USER_BUFFER")
} }
runtime.KeepAlive(buf)
} }

View File

@@ -28,6 +28,7 @@ const (
// Define maximum string length constants // Define maximum string length constants
// //
const ( const (
ANYSIZE_ARRAY = 1
LINE_LEN = 256 // Windows 9x-compatible maximum for displayable strings coming from a device INF. LINE_LEN = 256 // Windows 9x-compatible maximum for displayable strings coming from a device INF.
MAX_INF_STRING_LENGTH = 4096 // Actual maximum size of an INF string (including string substitutions). MAX_INF_STRING_LENGTH = 4096 // Actual maximum size of an INF string (including string substitutions).
MAX_INF_SECTION_NAME_LENGTH = 255 // For Windows 9x compatibility, INF section names should be constrained to 32 characters. MAX_INF_SECTION_NAME_LENGTH = 255 // For Windows 9x compatibility, INF section names should be constrained to 32 characters.
@@ -59,7 +60,7 @@ type DevInfoData struct {
// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass). // DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass).
type DevInfoListDetailData struct { type DevInfoListDetailData struct {
size uint32 size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const.
ClassGUID windows.GUID ClassGUID windows.GUID
RemoteMachineHandle windows.Handle RemoteMachineHandle windows.Handle
remoteMachineName [SP_MAX_MACHINENAME_LENGTH]uint16 remoteMachineName [SP_MAX_MACHINENAME_LENGTH]uint16
@@ -370,7 +371,7 @@ func (data *DrvInfoData) IsNewer(driverDate windows.Filetime, driverVersion uint
// DrvInfoDetailData is driver information details structure (provides detailed information about a particular driver information structure) // DrvInfoDetailData is driver information details structure (provides detailed information about a particular driver information structure)
type DrvInfoDetailData struct { type DrvInfoDetailData struct {
size uint32 // On input, this must be exactly the sizeof(DrvInfoDetailData). On output, we set this member to the actual size of structure data. size uint32 // Warning: unsafe.Sizeof(DrvInfoDetailData) > sizeof(SP_DRVINFO_DETAIL_DATA) when GOARCH == 386 => use sizeofDrvInfoDetailData const.
InfDate windows.Filetime InfDate windows.Filetime
compatIDsOffset uint32 compatIDsOffset uint32
compatIDsLength uint32 compatIDsLength uint32
@@ -378,7 +379,7 @@ type DrvInfoDetailData struct {
sectionName [LINE_LEN]uint16 sectionName [LINE_LEN]uint16
infFileName [windows.MAX_PATH]uint16 infFileName [windows.MAX_PATH]uint16
drvDescription [LINE_LEN]uint16 drvDescription [LINE_LEN]uint16
hardwareID [1]uint16 hardwareID [ANYSIZE_ARRAY]uint16
} }
func (data *DrvInfoDetailData) GetSectionName() string { func (data *DrvInfoDetailData) GetSectionName() string {

View File

@@ -0,0 +1,11 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package setupapi
const (
sizeofDevInfoListDetailData uint32 = 550
sizeofDrvInfoDetailData uint32 = 1570
)

View File

@@ -0,0 +1,11 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
*/
package setupapi
const (
sizeofDevInfoListDetailData uint32 = 560
sizeofDrvInfoDetailData uint32 = 1584
)

View File

@@ -14,7 +14,7 @@ import (
func TestSetupDiDestroyDeviceInfoList(t *testing.T) { func TestSetupDiDestroyDeviceInfoList(t *testing.T) {
err := SetupDiDestroyDeviceInfoList(DevInfo(windows.InvalidHandle)) err := SetupDiDestroyDeviceInfoList(DevInfo(windows.InvalidHandle))
if errWin, ok := err.(syscall.Errno); !ok || errWin != 6 /*ERROR_INVALID_HANDLE*/ { if errWin, ok := err.(syscall.Errno); !ok || errWin != windows.ERROR_INVALID_HANDLE {
t.Errorf("SetupDiDestroyDeviceInfoList(nil, ...) should fail with ERROR_INVALID_HANDLE") t.Errorf("SetupDiDestroyDeviceInfoList(nil, ...) should fail with ERROR_INVALID_HANDLE")
} }
} }

View File

@@ -8,7 +8,6 @@ package wintun
import ( import (
"errors" "errors"
"fmt" "fmt"
"golang.zx2c4.com/wireguard/tun/wintun/netshell"
"strings" "strings"
"syscall" "syscall"
"time" "time"
@@ -17,6 +16,8 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.org/x/sys/windows/registry" "golang.org/x/sys/windows/registry"
"golang.zx2c4.com/wireguard/tun/wintun/guid" "golang.zx2c4.com/wireguard/tun/wintun/guid"
"golang.zx2c4.com/wireguard/tun/wintun/netshell"
registryEx "golang.zx2c4.com/wireguard/tun/wintun/registry"
"golang.zx2c4.com/wireguard/tun/wintun/setupapi" "golang.zx2c4.com/wireguard/tun/wintun/setupapi"
) )
@@ -34,28 +35,23 @@ var deviceClassNetGUID = windows.GUID{Data1: 0x4d36e972, Data2: 0xe325, Data3: 0
const hardwareID = "Wintun" const hardwareID = "Wintun"
const enumerator = "" const enumerator = ""
const machineName = "" const machineName = ""
const waitForRegistryTimeout = time.Second * 5
// //
// MakeWintun creates interface handle and populates it from device registry key // MakeWintun creates interface handle and populates it from device registry key
// //
func MakeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) (*Wintun, error) { func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) (*Wintun, error) {
// Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key. // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.READ) key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE)
if err != nil { if err != nil {
return nil, errors.New("Device-specific registry key open failed: " + err.Error()) return nil, fmt.Errorf("Device-specific registry key open failed: %v", err)
} }
defer key.Close() defer key.Close()
var valueStr string
var valueType uint32
// Read the NetCfgInstanceId value. // Read the NetCfgInstanceId value.
valueStr, valueType, err = keyGetStringValueRetry(key, "NetCfgInstanceId") valueStr, err := registryEx.GetStringValue(key, "NetCfgInstanceId")
if err != nil { if err != nil {
return nil, errors.New("RegQueryStringValue(\"NetCfgInstanceId\") failed: " + err.Error()) return nil, fmt.Errorf("RegQueryStringValue(\"NetCfgInstanceId\") failed: %v", err)
}
if valueType != registry.SZ {
return nil, fmt.Errorf("NetCfgInstanceId registry value is not REG_SZ (expected: %v, provided: %v)", registry.SZ, valueType)
} }
// Convert to windows.GUID. // Convert to windows.GUID.
@@ -65,15 +61,15 @@ func MakeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfo
} }
// Read the NetLuidIndex value. // Read the NetLuidIndex value.
luidIdx, valueType, err := key.GetIntegerValue("NetLuidIndex") luidIdx, _, err := key.GetIntegerValue("NetLuidIndex")
if err != nil { if err != nil {
return nil, errors.New("RegQueryValue(\"NetLuidIndex\") failed: " + err.Error()) return nil, fmt.Errorf("RegQueryValue(\"NetLuidIndex\") failed: %v", err)
} }
// Read the NetLuidIndex value. // Read the NetLuidIndex value.
ifType, valueType, err := key.GetIntegerValue("*IfType") ifType, _, err := key.GetIntegerValue("*IfType")
if err != nil { if err != nil {
return nil, errors.New("RegQueryValue(\"*IfType\") failed: " + err.Error()) return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err)
} }
return &Wintun{ return &Wintun{
@@ -99,7 +95,7 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
// Create a list of network devices. // Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, enumerator, hwndParent, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), machineName) devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, enumerator, hwndParent, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), machineName)
if err != nil { if err != nil {
return nil, errors.New(fmt.Sprintf("SetupDiGetClassDevsEx(%v) failed: ", guid.ToString(&deviceClassNetGUID)) + err.Error()) return nil, fmt.Errorf("SetupDiGetClassDevsEx(%s) failed: %v", guid.ToString(&deviceClassNetGUID), err)
} }
defer devInfoList.Close() defer devInfoList.Close()
@@ -114,18 +110,19 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
// Get the device from the list. Should anything be wrong with this device, continue with next. // Get the device from the list. Should anything be wrong with this device, continue with next.
deviceData, err := devInfoList.EnumDeviceInfo(index) deviceData, err := devInfoList.EnumDeviceInfo(index)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
} }
// Get interface ID. // Get interface ID.
wintun, err := MakeWintun(devInfoList, deviceData) wintun, err := makeWintun(devInfoList, deviceData)
if err != nil { if err != nil {
continue continue
} }
//TODO: is there a better way than comparing ifnames?
// Get interface name. // Get interface name.
ifname2, err := wintun.GetInterfaceName() ifname2, err := wintun.GetInterfaceName()
if err != nil { if err != nil {
@@ -137,7 +134,7 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
const driverType = setupapi.SPDIT_COMPATDRIVER const driverType = setupapi.SPDIT_COMPATDRIVER
err = devInfoList.BuildDriverInfoList(deviceData, driverType) err = devInfoList.BuildDriverInfoList(deviceData, driverType)
if err != nil { if err != nil {
return nil, errors.New("SetupDiBuildDriverInfoList failed: " + err.Error()) return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
} }
defer devInfoList.DestroyDriverInfoList(deviceData, driverType) defer devInfoList.DestroyDriverInfoList(deviceData, driverType)
@@ -145,7 +142,7 @@ func GetInterface(ifname string, hwndParent uintptr) (*Wintun, error) {
// Get a driver from the list. // Get a driver from the list.
driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, index) driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, index)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
// Something is wrong with this driver. Skip it. // Something is wrong with this driver. Skip it.
@@ -191,52 +188,49 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
// Create an empty device info set for network adapter device class. // Create an empty device info set for network adapter device class.
devInfoList, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, hwndParent, machineName) devInfoList, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, hwndParent, machineName)
if err != nil { if err != nil {
return nil, false, errors.New(fmt.Sprintf("SetupDiCreateDeviceInfoListEx(%v) failed: ", guid.ToString(&deviceClassNetGUID)) + err.Error()) return nil, false, fmt.Errorf("SetupDiCreateDeviceInfoListEx(%s) failed: %v", guid.ToString(&deviceClassNetGUID), err)
} }
defer devInfoList.Close()
// Get the device class name from GUID. // Get the device class name from GUID.
className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, machineName) className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, machineName)
if err != nil { if err != nil {
return nil, false, errors.New(fmt.Sprintf("SetupDiClassNameFromGuidEx(%v) failed: ", guid.ToString(&deviceClassNetGUID)) + err.Error()) return nil, false, fmt.Errorf("SetupDiClassNameFromGuidEx(%s) failed: %v", guid.ToString(&deviceClassNetGUID), err)
} }
// Create a new device info element and add it to the device info set. // Create a new device info element and add it to the device info set.
deviceData, err := devInfoList.CreateDeviceInfo(className, &deviceClassNetGUID, description, hwndParent, setupapi.DICD_GENERATE_ID) deviceData, err := devInfoList.CreateDeviceInfo(className, &deviceClassNetGUID, description, hwndParent, setupapi.DICD_GENERATE_ID)
if err != nil { if err != nil {
return nil, false, errors.New("SetupDiCreateDeviceInfo failed: " + err.Error()) return nil, false, fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err)
} }
// Set a device information element as the selected member of a device information set. // Set a device information element as the selected member of a device information set.
err = devInfoList.SetSelectedDevice(deviceData) err = devInfoList.SetSelectedDevice(deviceData)
if err != nil { if err != nil {
return nil, false, errors.New("SetupDiSetSelectedDevice failed: " + err.Error()) return nil, false, fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err)
} }
// Set Plug&Play device hardware ID property. // Set Plug&Play device hardware ID property.
hwid, err := syscall.UTF16FromString(hardwareID) err = devInfoList.SetDeviceRegistryPropertyString(deviceData, setupapi.SPDRP_HARDWAREID, hardwareID)
if err != nil { if err != nil {
return nil, false, err // syscall.UTF16FromString(hardwareID) should never fail: hardwareID is const string without NUL chars. return nil, false, fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err)
}
err = devInfoList.SetDeviceRegistryProperty(deviceData, setupapi.SPDRP_HARDWAREID, setupapi.UTF16ToBuf(append(hwid, 0)))
if err != nil {
return nil, false, errors.New("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: " + err.Error())
} }
// Search for the driver. // Search for the driver.
const driverType = setupapi.SPDIT_CLASSDRIVER const driverType = setupapi.SPDIT_COMPATDRIVER
err = devInfoList.BuildDriverInfoList(deviceData, driverType) err = devInfoList.BuildDriverInfoList(deviceData, driverType) //TODO: This takes ~510ms
if err != nil { if err != nil {
return nil, false, errors.New("SetupDiBuildDriverInfoList failed: " + err.Error()) return nil, false, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
} }
defer devInfoList.DestroyDriverInfoList(deviceData, driverType) defer devInfoList.DestroyDriverInfoList(deviceData, driverType)
driverDate := windows.Filetime{} driverDate := windows.Filetime{}
driverVersion := uint64(0) driverVersion := uint64(0)
for index := 0; ; index++ { for index := 0; ; index++ { //TODO: This loop takes ~600ms
// Get a driver from the list. // Get a driver from the list.
driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, index) driverData, err := devInfoList.EnumDriverInfo(deviceData, driverType, index)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
// Something is wrong with this driver. Skip it. // Something is wrong with this driver. Skip it.
@@ -273,7 +267,7 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
// Call appropriate class installer. // Call appropriate class installer.
err = devInfoList.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, deviceData) err = devInfoList.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, deviceData)
if err != nil { if err != nil {
return nil, false, errors.New("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: " + err.Error()) return nil, false, fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err)
} }
// Register device co-installers if any. (Ignore errors) // Register device co-installers if any. (Ignore errors)
@@ -282,37 +276,104 @@ func CreateInterface(description string, hwndParent uintptr) (*Wintun, bool, err
// Install interfaces if any. (Ignore errors) // Install interfaces if any. (Ignore errors)
devInfoList.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, deviceData) devInfoList.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, deviceData)
var wintun *Wintun
var rebootRequired bool
// Install the device. // Install the device.
err = devInfoList.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, deviceData) err = devInfoList.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, deviceData)
if err != nil { if err != nil {
err = errors.New("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: " + err.Error()) err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err)
} }
var wintun *Wintun
var rebootRequired bool
var key registry.Key
if err == nil { if err == nil {
// Check if a system reboot is required. (Ignore errors) // Check if a system reboot is required. (Ignore errors)
if ret, _ := checkReboot(devInfoList, deviceData); ret { if ret, _ := checkReboot(devInfoList, deviceData); ret {
rebootRequired = true rebootRequired = true
} }
// Get network interface. DIF_INSTALLDEVICE returns almost immediately, while the device // DIF_INSTALLDEVICE returns almost immediately, while the device installation
// installation continues in the background. It might take a while, before all registry // continues in the background. It might take a while, before all registry
// keys and values are populated. // keys and values are populated.
for numAttempts := 0; numAttempts < 30; numAttempts++ { const pollTimeout = time.Millisecond * 50
wintun, err = MakeWintun(devInfoList, deviceData) for i := 0; i < int(waitForRegistryTimeout/pollTimeout); i++ {
if err != nil { if i != 0 {
if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_FILE_NOT_FOUND { time.Sleep(pollTimeout)
// Wait and retry. TODO: Wait for a cancellable event instead. }
err = errors.New("Time-out waiting for adapter to get ready") key, err = devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE|registry.NOTIFY)
time.Sleep(time.Second / 4) if err == nil {
continue break
}
}
if err == nil {
_, err = registryEx.GetStringValueWait(key, "NetCfgInstanceId", waitForRegistryTimeout)
if err == nil {
_, err = registryEx.GetIntegerValueWait(key, "NetLuidIndex", waitForRegistryTimeout)
}
if err == nil {
_, err = registryEx.GetIntegerValueWait(key, "*IfType", waitForRegistryTimeout)
}
key.Close()
} }
} }
break if err == nil {
// Get network interface.
wintun, err = makeWintun(devInfoList, deviceData)
} }
if err == nil {
// Wait for network registry key to emerge and populate.
key, err = registryEx.OpenKeyWait(
registry.LOCAL_MACHINE,
wintun.GetNetRegKeyName(),
registry.QUERY_VALUE|registry.NOTIFY,
waitForRegistryTimeout)
if err == nil {
_, err = registryEx.GetStringValueWait(key, "Name", waitForRegistryTimeout)
key.Close()
}
}
if err == nil {
// Wait for TCP/IP adapter registry key to emerge and populate.
key, err = registryEx.OpenKeyWait(
registry.LOCAL_MACHINE,
wintun.GetTcpipAdapterRegKeyName(), registry.QUERY_VALUE|registry.NOTIFY,
waitForRegistryTimeout)
if err == nil {
_, err = registryEx.GetStringValueWait(key, "IpConfig", waitForRegistryTimeout)
key.Close()
}
}
var tcpipInterfaceRegKeyName string
if err == nil {
tcpipInterfaceRegKeyName, err = wintun.GetTcpipInterfaceRegKeyName()
if err == nil {
// Wait for TCP/IP interface registry key to emerge.
key, err = registryEx.OpenKeyWait(
registry.LOCAL_MACHINE,
tcpipInterfaceRegKeyName, registry.QUERY_VALUE,
waitForRegistryTimeout)
if err == nil {
key.Close()
}
}
}
//
// All the registry keys and values we're relying on are present now.
//
if err == nil {
// Disable dead gateway detection on our interface.
key, err = registry.OpenKey(registry.LOCAL_MACHINE, tcpipInterfaceRegKeyName, registry.SET_VALUE)
if err != nil {
err = fmt.Errorf("Error opening interface-specific TCP/IP network registry key: %v", err)
}
key.SetDWordValue("EnableDeadGWDetect", 0)
key.Close()
} }
if err == nil { if err == nil {
@@ -355,7 +416,7 @@ func (wintun *Wintun) DeleteInterface(hwndParent uintptr) (bool, bool, error) {
// Create a list of network devices. // Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, enumerator, hwndParent, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), machineName) devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, enumerator, hwndParent, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), machineName)
if err != nil { if err != nil {
return false, false, errors.New(fmt.Sprintf("SetupDiGetClassDevsEx(%v) failed: ", guid.ToString(&deviceClassNetGUID)) + err.Error()) return false, false, fmt.Errorf("SetupDiGetClassDevsEx(%s) failed: %v", guid.ToString(&deviceClassNetGUID), err.Error())
} }
defer devInfoList.Close() defer devInfoList.Close()
@@ -364,14 +425,15 @@ func (wintun *Wintun) DeleteInterface(hwndParent uintptr) (bool, bool, error) {
// Get the device from the list. Should anything be wrong with this device, continue with next. // Get the device from the list. Should anything be wrong with this device, continue with next.
deviceData, err := devInfoList.EnumDeviceInfo(index) deviceData, err := devInfoList.EnumDeviceInfo(index)
if err != nil { if err != nil {
if errWin, ok := err.(syscall.Errno); ok && errWin == 259 /*ERROR_NO_MORE_ITEMS*/ { if errWin, ok := err.(syscall.Errno); ok && errWin == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
} }
// Get interface ID. // Get interface ID.
wintun2, err := MakeWintun(devInfoList, deviceData) //TODO: Store some ID in the Wintun object such that this call isn't required.
wintun2, err := makeWintun(devInfoList, deviceData)
if err != nil { if err != nil {
continue continue
} }
@@ -386,13 +448,13 @@ func (wintun *Wintun) DeleteInterface(hwndParent uintptr) (bool, bool, error) {
// Set class installer parameters for DIF_REMOVE. // Set class installer parameters for DIF_REMOVE.
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
if err != nil { if err != nil {
return false, false, errors.New("SetupDiSetClassInstallParams failed: " + err.Error()) return false, false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err)
} }
// Call appropriate class installer. // Call appropriate class installer.
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData)
if err != nil { if err != nil {
return false, false, errors.New("SetupDiCallClassInstaller failed: " + err.Error()) return false, false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err)
} }
// Check if a system reboot is required. (Ignore errors) // Check if a system reboot is required. (Ignore errors)
@@ -436,34 +498,33 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
// GetInterfaceName returns network interface name. // GetInterfaceName returns network interface name.
// //
func (wintun *Wintun) GetInterfaceName() (string, error) { func (wintun *Wintun) GetInterfaceName() (string, error) {
key, err := registryOpenKeyRetry(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE) key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.QUERY_VALUE)
if err != nil { if err != nil {
return "", errors.New("Network-specific registry key open failed: " + err.Error()) return "", fmt.Errorf("Network-specific registry key open failed: %v", err)
} }
defer key.Close() defer key.Close()
// Get the interface name. // Get the interface name.
return getRegStringValue(key, "Name") return registryEx.GetStringValue(key, "Name")
} }
// //
// SetInterfaceName sets network interface name. // SetInterfaceName sets network interface name.
// //
func (wintun *Wintun) SetInterfaceName(ifname string) error { func (wintun *Wintun) SetInterfaceName(ifname string) error {
// We open the registry key before calling HrRename, because the registry open will wait until the key exists.
key, err := registryOpenKeyRetry(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE)
if err != nil {
return errors.New("Network-specific registry key open failed: " + err.Error())
}
defer key.Close()
// We have to tell the various runtime COM services about the new name too. We ignore the // We have to tell the various runtime COM services about the new name too. We ignore the
// error because netshell isn't available on servercore. // error because netshell isn't available on servercore. It's also slow, so we run it in a
// separate thread.
// TODO: netsh.exe falls back to NciSetConnection in this case. If somebody complains, maybe // TODO: netsh.exe falls back to NciSetConnection in this case. If somebody complains, maybe
// we should do the same. // we should do the same.
_ = netshell.HrRenameConnection(&wintun.CfgInstanceID, windows.StringToUTF16Ptr(ifname)) go netshell.HrRenameConnection(&wintun.CfgInstanceID, windows.StringToUTF16Ptr(ifname))
// Set the interface name. The above line should have done this too, but in case it failed, we force it. // Set the interface name. The above line should have done this too, but in case it failed, we force it.
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetNetRegKeyName(), registry.SET_VALUE)
if err != nil {
return fmt.Errorf("Network-specific registry key open failed: %v", err)
}
defer key.Close()
return key.SetStringValue("Name", ifname) return key.SetStringValue("Name", ifname)
} }
@@ -471,35 +532,33 @@ func (wintun *Wintun) SetInterfaceName(ifname string) error {
// GetNetRegKeyName returns interface-specific network registry key name. // GetNetRegKeyName returns interface-specific network registry key name.
// //
func (wintun *Wintun) GetNetRegKeyName() string { func (wintun *Wintun) GetNetRegKeyName() string {
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Control\\Network\\%v\\%v\\Connection", guid.ToString(&deviceClassNetGUID), guid.ToString(&wintun.CfgInstanceID)) return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Control\\Network\\%s\\%s\\Connection", guid.ToString(&deviceClassNetGUID), guid.ToString(&wintun.CfgInstanceID))
} }
// //
// getRegStringValue function reads a string value from registry. // GetTcpipAdapterRegKeyName returns adapter-specific TCP/IP network registry key name.
// //
// If the value type is REG_EXPAND_SZ the environment variables are expanded. func (wintun *Wintun) GetTcpipAdapterRegKeyName() string {
// Should expanding fail, original string value and nil error are returned. return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%s", guid.ToString(&wintun.CfgInstanceID))
}
// //
func getRegStringValue(key registry.Key, name string) (string, error) { // GetTcpipInterfaceRegKeyName returns interface-specific TCP/IP network registry key name.
// Read string value. //
value, valueType, err := keyGetStringValueRetry(key, name) func (wintun *Wintun) GetTcpipInterfaceRegKeyName() (path string, err error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.GetTcpipAdapterRegKeyName(), registry.QUERY_VALUE)
if err != nil { if err != nil {
return "", err return "", fmt.Errorf("Error opening adapter-specific TCP/IP network registry key: %v", err)
} }
paths, _, err := key.GetStringsValue("IpConfig")
if valueType != registry.EXPAND_SZ { key.Close()
// Value does not require expansion.
return value, nil
}
valueExp, err := registry.ExpandString(value)
if err != nil { if err != nil {
// Expanding failed: return original sting value. return "", fmt.Errorf("Error reading IpConfig registry key: %v", err)
return value, nil
} }
if len(paths) == 0 {
// Return expanded value. return "", errors.New("No TCP/IP interfaces found on adapter")
return valueExp, nil }
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil
} }
// //