23 Commits

Author SHA1 Message Date
Jason A. Donenfeld
ae88e2a2cd version: bump snapshot 2020-03-20 12:00:53 -06:00
Jason A. Donenfeld
4739708ca4 noise: unify zero checking of ecdh 2020-03-17 23:07:14 -06:00
Tobias Klauser
b33219c2cf global: use RTMGRP_* consts from x/sys/unix
Update the golang.org/x/sys/unix dependency and use the newly introduced
RTMGRP_* consts instead of using the corresponding RTNLGRP_* const to
create a mask.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
2020-03-17 23:07:11 -06:00
Jason A. Donenfeld
9cbcff10dd send: account for zero mtu
Don't divide by zero.
2020-02-14 18:53:55 +01:00
Jason A. Donenfeld
6ed56ff2df device: fix private key removal logic 2020-02-04 22:02:53 +01:00
Jason A. Donenfeld
cb4bb63030 uapi: allow unsetting device private key with /dev/null 2020-02-04 22:02:53 +01:00
Jason A. Donenfeld
05b03c6750 version: bump snapshot 2020-01-21 16:27:19 +01:00
Jason A. Donenfeld
caebdfe9d0 tun: darwin: ignore ENOMEM errors
Coauthored-by: Andrej Mihajlov <and@mullvad.net>
2020-01-15 13:39:37 -05:00
Jason A. Donenfeld
4fa2ea6a2d tun: windows: serialize write calls 2020-01-07 11:40:45 -05:00
Jason A. Donenfeld
89dd065e53 README: update repo urls 2019-12-30 11:53:39 +01:00
Jason A. Donenfeld
ddfad453cf device: SendmsgN mutates the input sockaddr
So we take a new granular lock to prevent concurrent writes from
racing.

WARNING: DATA RACE
Write at 0x00c0011f2740 by goroutine 27:
  golang.org/x/sys/unix.(*SockaddrInet4).sockaddr()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:384
+0x114
  golang.org/x/sys/unix.SendmsgN()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:1304
+0x288
  golang.zx2c4.com/wireguard/device.send4()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:485
+0x11f
  golang.zx2c4.com/wireguard/device.(*nativeBind).Send()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:268
+0x1d6
  golang.zx2c4.com/wireguard/device.(*Peer).SendBuffer()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/peer.go:151
+0x285
  golang.zx2c4.com/wireguard/device.(*Peer).SendHandshakeInitiation()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:163
+0x692
  golang.zx2c4.com/wireguard/device.(*Device).RoutineReadFromTUN()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:318
+0x4b8

Previous write at 0x00c0011f2740 by goroutine 386:
  golang.org/x/sys/unix.(*SockaddrInet4).sockaddr()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:384
+0x114
  golang.org/x/sys/unix.SendmsgN()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:1304
+0x288
  golang.zx2c4.com/wireguard/device.send4()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:485
+0x11f
  golang.zx2c4.com/wireguard/device.(*nativeBind).Send()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:268
+0x1d6
  golang.zx2c4.com/wireguard/device.(*Peer).SendBuffer()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/peer.go:151
+0x285
  golang.zx2c4.com/wireguard/device.(*Peer).SendHandshakeInitiation()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:163
+0x692
  golang.zx2c4.com/wireguard/device.expiredRetransmitHandshake()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/timers.go:110
+0x40c
  golang.zx2c4.com/wireguard/device.(*Peer).NewTimer.func1()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/timers.go:42
+0xd8

Goroutine 27 (running) created at:
  golang.zx2c4.com/wireguard/device.NewDevice()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/device.go:322
+0x5e8
  main.main()
      /go/src/x/main.go:102 +0x58e

Goroutine 386 (finished) created at:
  time.goFunc()
      /usr/local/go/src/time/sleep.go:168 +0x51

Reported-by: Ben Burkert <ben@benburkert.com>
2019-11-28 11:11:13 +01:00
Jason A. Donenfeld
2b242f9393 wintun: manage ring memory manually
It's large and Go's garbage collector doesn't deal with it especially
well.
2019-11-22 13:13:55 +01:00
Jason A. Donenfeld
4cdf805b29 constants: recalculate rekey max based on a one minute flood
Discussed-with: Mathias Hall-Andersen <mathias@hall-andersen.dk>
2019-10-30 14:29:32 +01:00
Jonathan Tooker
f7d0edd2ec global: fix a few typos courtesy of codespell
Signed-off-by: Jonathan Tooker <jonathan.tooker@netprotect.com>
2019-10-22 11:51:25 +02:00
Jason A. Donenfeld
ffffbbcc8a device: allow blackholing sockets 2019-10-21 13:29:57 +02:00
Jason A. Donenfeld
47b02c618b device: remove dead error reporting code 2019-10-21 11:46:54 +02:00
Jason A. Donenfeld
fd23c66fcd namespaceapi: remove tasteless comment 2019-10-21 09:02:29 +02:00
Jason A. Donenfeld
ae492d1b35 device: recheck counters while holding write lock 2019-10-17 15:43:06 +02:00
Jason A. Donenfeld
95fbfccf60 wintun: normalize variable names for their types 2019-10-17 15:30:56 +02:00
Avery Pennarun
c85e4a410f wintun: quickly ignore non-Wintun devices
Some devices take ~2 seconds to enumerate on Windows if we try to get
their instance name.  The hardware id property, on the other hand,
is available right away.

Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
[zx2c4: inlined this to where it makes sense, reused setupapi const]
2019-10-17 15:19:20 +02:00
Avery Pennarun
1b6c8ddbe8 tun: match windows CreateTUN signature to the Linux variant
Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
[zx2c4: fix default value]
2019-10-17 15:19:20 +02:00
Avery Pennarun
0abb6b668c rwcancel: handle EINTR and EAGAIN in unixSelect()
On my Chromebook (Linux 4.19.44 in a VM) and on an AWS EC2
machine, select() was sometimes returning EINTR. This is
harmless and just means you should try again. So let's try
again.

This eliminates a problem where the tunnel fails to come up
correctly and the program needs to be restarted.

Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
2019-10-17 15:19:17 +02:00
David Crawshaw
540d01e54a device: test packets between two fake devices
Signed-off-by: David Crawshaw <crawshaw@tailscale.io>
2019-10-16 11:38:28 +02:00
25 changed files with 498 additions and 236 deletions

View File

@@ -18,7 +18,7 @@ To run wireguard-go without forking to the background, pass `-f` or `--foregroun
$ wireguard-go -f wg0
```
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/WireGuard/about/src/tools/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.

View File

@@ -18,7 +18,7 @@ const (
sockoptIPV6_UNICAST_IF = 31
)
func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
bytes := make([]byte, 4)
binary.BigEndian.PutUint32(bytes, interfaceIndex)
@@ -41,10 +41,11 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
if err != nil {
return err
}
device.net.bind.(*nativeBind).blackhole4 = blackhole
return nil
}
func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
if err != nil {
return err
@@ -58,5 +59,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
if err != nil {
return err
}
device.net.bind.(*nativeBind).blackhole6 = blackhole
return nil
}

View File

@@ -23,6 +23,8 @@ import (
type nativeBind struct {
ipv4 *net.UDPConn
ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
}
type NativeEndpoint net.UDPAddr
@@ -159,11 +161,17 @@ func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
if bind.ipv4 == nil {
return syscall.EAFNOSUPPORT
}
if bind.blackhole4 {
return nil
}
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
} else {
if bind.ipv6 == nil {
return syscall.EAFNOSUPPORT
}
if bind.blackhole6 {
return nil
}
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
}
return err

View File

@@ -7,7 +7,7 @@
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code:
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
* https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
@@ -43,6 +43,7 @@ type IPv6Source struct {
}
type NativeEndpoint struct {
sync.Mutex
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
src [unsafe.Sizeof(IPv6Source{})]byte
isV6 bool
@@ -117,7 +118,7 @@ func createNetlinkRouteSocket() (int, error) {
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
Groups: unix.RTMGRP_IPV4_ROUTE,
}
err = unix.Bind(sock, saddr)
if err != nil {
@@ -145,7 +146,7 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
go bind.routineRouteListener(device)
// attempt ipv6 bind, update port if succesful
// attempt ipv6 bind, update port if successful
bind.sock6, newPort, err = create6(port)
if err != nil {
@@ -157,7 +158,7 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
port = newPort
}
// attempt ipv4 bind, update port if succesful
// attempt ipv4 bind, update port if successful
bind.sock4, newPort, err = create4(port)
if err != nil {
@@ -482,7 +483,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
},
}
end.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.Unlock()
if err == nil {
return nil
@@ -493,7 +496,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{}
end.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.Unlock()
}
return err
@@ -522,7 +527,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
cmsg.pktinfo.Ifindex = 0
}
end.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.Unlock()
if err == nil {
return nil
@@ -533,7 +540,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{}
end.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.Unlock()
}
return err
@@ -541,7 +550,7 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
// construct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
@@ -573,7 +582,7 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
// construct message header
var cmsg struct {
cmsghdr unix.Cmsghdr

View File

@@ -12,7 +12,7 @@ import (
/* Specification constants */
const (
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
RekeyAfterMessages = (1 << 60)
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90

View File

@@ -236,24 +236,12 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// do static-static DH pre-computations
rmKey := device.staticIdentity.privateKey.IsZero()
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for key, peer := range device.peers.keyMap {
for _, peer := range device.peers.keyMap {
handshake := &peer.handshake
if rmKey {
handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
}
if isZero(handshake.precomputedStaticStatic[:]) {
unsafeRemovePeer(device, peer, key)
} else {
expiredPeers = append(expiredPeers, peer)
}
}
for _, peer := range lockedPeers {
peer.handshake.mutex.RUnlock()

View File

@@ -5,54 +5,212 @@
package device
/* Create two device instances and simulate full WireGuard interaction
* without network dependencies
*/
import (
"bufio"
"bytes"
"encoding/binary"
"io"
"net"
"os"
"strings"
"testing"
"time"
"golang.zx2c4.com/wireguard/tun"
)
func TestDevice(t *testing.T) {
// prepare tun devices for generating traffic
tun1 := newDummyTUN("tun1")
tun2 := newDummyTUN("tun2")
_ = tun1
_ = tun2
// prepare endpoints
end1, err := CreateDummyEndpoint()
if err != nil {
t.Error("failed to create endpoint:", err.Error())
}
end2, err := CreateDummyEndpoint()
if err != nil {
t.Error("failed to create endpoint:", err.Error())
}
_ = end1
_ = end2
// create binds
}
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
func TestTwoDevicePing(t *testing.T) {
// TODO(crawshaw): pick unused ports on localhost
cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58
listen_port=53511
replace_peers=true
public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725
protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.2/32
endpoint=127.0.0.1:53512`
tun1 := NewChannelTUN()
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
dev1.Up()
defer dev1.Close()
if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil {
t.Fatal(err)
}
tun := newDummyTUN("dummy")
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768
listen_port=53512
replace_peers=true
public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427
protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.1/32
endpoint=127.0.0.1:53511`
tun2 := NewChannelTUN()
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
dev2.Up()
defer dev2.Close()
if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil {
t.Fatal(err)
}
t.Run("ping 1.0.0.1", func(t *testing.T) {
msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
tun2.Outbound <- msg2to1
select {
case msgRecv := <-tun1.Inbound:
if !bytes.Equal(msg2to1, msgRecv) {
t.Error("ping did not transit correctly")
}
case <-time.After(300 * time.Millisecond):
t.Error("ping did not transit")
}
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
tun1.Outbound <- msg1to2
select {
case msgRecv := <-tun2.Inbound:
if !bytes.Equal(msg1to2, msgRecv) {
t.Error("return ping did not transit correctly")
}
case <-time.After(300 * time.Millisecond):
t.Error("return ping did not transit")
}
})
}
func ping(dst, src net.IP) []byte {
localPort := uint16(1337)
seq := uint16(0)
payload := make([]byte, 4)
binary.BigEndian.PutUint16(payload[0:], localPort)
binary.BigEndian.PutUint16(payload[2:], seq)
return genICMPv4(payload, dst, src)
}
// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
func checksum(buf []byte, initial uint16) uint16 {
v := uint32(initial)
for i := 0; i < len(buf)-1; i += 2 {
v += uint32(binary.BigEndian.Uint16(buf[i:]))
}
if len(buf)%2 == 1 {
v += uint32(buf[len(buf)-1]) << 8
}
for v > 0xffff {
v = (v >> 16) + (v & 0xffff)
}
return ^uint16(v)
}
func genICMPv4(payload []byte, dst, src net.IP) []byte {
const (
icmpv4ProtocolNumber = 1
icmpv4Echo = 8
icmpv4ChecksumOffset = 2
icmpv4Size = 8
ipv4Size = 20
ipv4TotalLenOffset = 2
ipv4ChecksumOffset = 10
ttl = 65
)
hdr := make([]byte, ipv4Size+icmpv4Size)
ip := hdr[0:ipv4Size]
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
// https://tools.ietf.org/html/rfc792
icmpv4[0] = icmpv4Echo // type
icmpv4[1] = 0 // code
chksum := ^checksum(icmpv4, checksum(payload, 0))
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
// https://tools.ietf.org/html/rfc760 section 3.1
length := uint16(len(hdr) + len(payload))
ip[0] = (4 << 4) | (ipv4Size / 4)
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl
ip[9] = icmpv4ProtocolNumber
copy(ip[12:], src.To4())
copy(ip[16:], dst.To4())
chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
var v []byte
v = append(v, hdr...)
v = append(v, payload...)
return []byte(v)
}
// TODO(crawshaw): find a reusable home for this. package devicetest?
type ChannelTUN struct {
Inbound chan []byte // incoming packets, closed on TUN close
Outbound chan []byte // outbound packets, blocks forever on TUN close
closed chan struct{}
events chan tun.Event
tun chTun
}
func NewChannelTUN() *ChannelTUN {
c := &ChannelTUN{
Inbound: make(chan []byte),
Outbound: make(chan []byte),
closed: make(chan struct{}),
events: make(chan tun.Event, 1),
}
c.tun.c = c
c.events <- tun.EventUp
return c
}
func (c *ChannelTUN) TUN() tun.Device {
return &c.tun
}
type chTun struct {
c *ChannelTUN
}
func (t *chTun) File() *os.File { return nil }
func (t *chTun) Read(data []byte, offset int) (int, error) {
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case msg := <-t.c.Outbound:
return copy(data[offset:], msg), nil
}
}
// Write is called by the wireguard device to deliver a packet for routing.
func (t *chTun) Write(data []byte, offset int) (int, error) {
if offset == -1 {
close(t.c.closed)
close(t.c.events)
return 0, io.EOF
}
msg := make([]byte, len(data)-offset)
copy(msg, data[offset:])
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case t.c.Inbound <- msg:
return len(data) - offset, nil
}
}
func (t *chTun) Flush() error { return nil }
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
func (t *chTun) Events() chan tun.Event { return t.c.events }
func (t *chTun) Close() error {
t.Write(nil, -1)
return nil
}
func assertNil(t *testing.T, err error) {
@@ -66,3 +224,15 @@ func assertEqual(t *testing.T, a, b []byte) {
t.Fatal(a, "!=", b)
}
}
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
tun := newDummyTUN("dummy")
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
}

View File

@@ -39,13 +39,13 @@ const (
)
const (
MessageInitiationSize = 148 // size of handshake initation message
MessageInitiationSize = 148 // size of handshake initiation message
MessageResponseSize = 92 // size of response message
MessageCookieReplySize = 64 // size of cookie reply message
MessageTransportHeaderSize = 16 // size of data preceeding content in transport message
MessageTransportHeaderSize = 16 // size of data preceding content in transport message
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
MessageKeepaliveSize = MessageTransportSize // size of keepalive
MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message
MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
)
const (
@@ -154,6 +154,7 @@ func init() {
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
var errZeroECDHResult = errors.New("ECDH returned all zeros")
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
@@ -162,12 +163,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errors.New("static shared secret is zero")
}
// create ephemeral key
var err error
handshake.hash = InitialHash
handshake.chainKey = InitialChainKey
@@ -176,31 +172,22 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
return nil, err
}
// assign index
device.indexTable.Delete(handshake.localIndex)
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
handshake.mixHash(handshake.remoteStatic[:])
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
Sender: handshake.localIndex,
}
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
func() {
var key [chacha20poly1305.KeySize]byte
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if isZero(ss[:]) {
return nil, errZeroECDHResult
}
var key [chacha20poly1305.KeySize]byte
KDF2(
&handshake.chainKey,
&key,
@@ -209,23 +196,29 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
)
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
}()
handshake.mixHash(msg.Static[:])
// encrypt timestamp
timestamp := tai64n.Now()
func() {
var key [chacha20poly1305.KeySize]byte
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errZeroECDHResult
}
KDF2(
&handshake.chainKey,
&key,
handshake.chainKey[:],
handshake.precomputedStaticStatic[:],
)
aead, _ := chacha20poly1305.New(key[:])
timestamp := tai64n.Now()
aead, _ = chacha20poly1305.New(key[:])
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
}()
// assign index
device.indexTable.Delete(handshake.localIndex)
msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
handshake.localIndex = msg.Sender
handshake.mixHash(msg.Timestamp[:])
handshake.state = HandshakeInitiationCreated
@@ -250,16 +243,16 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
var err error
var peerPK NoisePublicKey
func() {
var key [chacha20poly1305.KeySize]byte
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
if isZero(ss[:]) {
return nil
}
KDF2(&chainKey, &key, chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
}()
if err != nil {
return nil
}
@@ -273,23 +266,24 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
}
handshake := &peer.handshake
if isZero(handshake.precomputedStaticStatic[:]) {
return nil
}
// verify identity
var timestamp tai64n.Timestamp
var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock()
if isZero(handshake.precomputedStaticStatic[:]) {
handshake.mutex.RUnlock()
return nil
}
KDF2(
&chainKey,
&key,
chainKey[:],
handshake.precomputedStaticStatic[:],
)
aead, _ := chacha20poly1305.New(key[:])
aead, _ = chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
if err != nil {
handshake.mutex.RUnlock()
@@ -315,8 +309,13 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral
if timestamp.After(handshake.lastTimestamp) {
handshake.lastTimestamp = timestamp
handshake.lastInitiationConsumption = time.Now()
}
now := time.Now()
if now.After(handshake.lastInitiationConsumption) {
handshake.lastInitiationConsumption = now
}
handshake.state = HandshakeInitiationConsumed
handshake.mutex.Unlock()

View File

@@ -52,6 +52,15 @@ func (key *NoisePrivateKey) FromHex(src string) (err error) {
return
}
func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
err = loadExactHex(key[:], src)
if key.IsZero() {
return
}
key.clamp()
return
}
func (key NoisePrivateKey) ToHex() string {
return hex.EncodeToString(key[:])
}

View File

@@ -108,7 +108,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
ssIsZero := isZero(handshake.precomputedStaticStatic[:])
handshake.remoteStatic = pk
handshake.mutex.Unlock()
@@ -116,13 +115,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.endpoint = nil
// conditionally add
// add
if !ssIsZero {
device.peers.keyMap[pk] = peer
} else {
return nil, nil
}
// start peer

View File

@@ -220,10 +220,7 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, reply)
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
if err != nil {
device.log.Error.Println("Failed to send cookie reply:", err)
}
return err
return nil
}
func (peer *Peer) keepKeyFreshSending() {
@@ -518,11 +515,19 @@ func (device *Device) RoutineEncryption() {
// pad content to multiple of 16
mtu := int(atomic.LoadInt32(&device.tun.mtu))
lastUnit := len(elem.packet) % mtu
var paddedSize int
if mtu == 0 {
paddedSize = (len(elem.packet) + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
} else {
lastUnit := len(elem.packet)
if lastUnit > mtu {
lastUnit %= mtu
}
paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
if paddedSize > mtu {
paddedSize = mtu
}
}
for i := len(elem.packet); i < paddedSize; i++ {
elem.packet = append(elem.packet, 0)
}

View File

@@ -138,7 +138,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
switch key {
case "private_key":
var sk NoisePrivateKey
err := sk.FromHex(value)
err := sk.FromMaybeZeroHex(value)
if err != nil {
logError.Println("Failed to set private_key:", err)
return &IPCError{ipc.IpcErrorInvalid}

View File

@@ -1,3 +1,3 @@
package device
const WireGuardGoVersion = "0.0.20191012"
const WireGuardGoVersion = "0.0.20200320"

2
go.mod
View File

@@ -5,6 +5,6 @@ go 1.12
require (
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc
golang.org/x/net v0.0.0-20191003171128-d98b1b443823
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c
golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527
golang.org/x/text v0.3.2
)

4
go.sum
View File

@@ -6,8 +6,8 @@ golang.org/x/net v0.0.0-20191003171128-d98b1b443823 h1:Ypyv6BNJh07T1pUSrehkLemqP
golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c h1:6Zx7DRlKXf79yfxuQ/7GqV3w2y7aDsk6bGg0MzF5RVU=
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=
golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=

View File

@@ -37,7 +37,7 @@ func main() {
logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
logger.Debug.Println("Debug log enabled")
tun, err := tun.CreateTUN(interfaceName)
tun, err := tun.CreateTUN(interfaceName, 0)
if err == nil {
realInterfaceName, err2 := tun.Name()
if err2 == nil {

View File

@@ -36,7 +36,7 @@ func TestRatelimiter(t *testing.T) {
for i := 0; i < packetsBurstable; i++ {
Add(RatelimiterResult{
allowed: true,
text: "inital burst",
text: "initial burst",
})
}

View File

@@ -60,7 +60,13 @@ func (rw *RWCancel) ReadyRead() bool {
fdset := fdSet{}
fdset.set(rw.fd)
fdset.set(closeFd)
err := unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
var err error
for {
err = unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
if err == nil || !RetryAfterError(err) {
break
}
}
if err != nil {
return false
}
@@ -75,7 +81,13 @@ func (rw *RWCancel) ReadyWrite() bool {
fdset := fdSet{}
fdset.set(rw.fd)
fdset.set(closeFd)
err := unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
var err error
for {
err = unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
if err == nil || !RetryAfterError(err) {
break
}
}
if err != nil {
return false
}

View File

@@ -11,6 +11,7 @@ import (
"net"
"os"
"syscall"
"time"
"unsafe"
"golang.org/x/net/ipv6"
@@ -42,6 +43,22 @@ type NativeTun struct {
var sockaddrCtlSize uintptr = 32
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
for i := 0; i < 20; i++ {
iface, err = net.InterfaceByIndex(index)
if err != nil {
if opErr, ok := err.(*net.OpError); ok {
if syscallErr, ok := opErr.Err.(*os.SyscallError); ok && syscallErr.Err == syscall.ENOMEM {
time.Sleep(time.Duration(i) * time.Second / 3)
continue
}
}
}
return iface, err
}
return nil, err
}
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
var (
statusUp bool
@@ -74,7 +91,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
continue
}
iface, err := net.InterfaceByIndex(ifindex)
iface, err := retryInterfaceByIndex(ifindex)
if err != nil {
tun.errors <- err
return

View File

@@ -35,7 +35,7 @@ type NativeTun struct {
name string // name of interface
errors chan error // async error handling
events chan Event // device related events
nopi bool // the device was pased IFF_NO_PI
nopi bool // the device was passed IFF_NO_PI
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
@@ -85,7 +85,7 @@ func createNetlinkSocket() (int, error) {
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))),
Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
}
err = unix.Bind(sock, saddr)
if err != nil {

View File

@@ -9,6 +9,7 @@ import (
"errors"
"fmt"
"os"
"sync"
"sync/atomic"
"time"
"unsafe"
@@ -35,11 +36,12 @@ type NativeTun struct {
wt *wintun.Interface
handle windows.Handle
close bool
rings wintun.RingDescriptor
events chan Event
errors chan error
forcedMTU int
rate rateJuggler
rings *wintun.RingDescriptor
writeLock sync.Mutex
}
const WintunPool = wintun.Pool("WireGuard")
@@ -54,15 +56,15 @@ func nanotime() int64
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused.
//
func CreateTUN(ifname string) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, nil)
func CreateTUN(ifname string, mtu int) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, nil, mtu)
}
//
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
//
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Device, error) {
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
var err error
var wt *wintun.Interface
@@ -80,21 +82,26 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev
return nil, fmt.Errorf("Error creating interface: %v", err)
}
forcedMTU := 1420
if mtu > 0 {
forcedMTU = mtu
}
tun := &NativeTun{
wt: wt,
handle: windows.InvalidHandle,
events: make(chan Event, 10),
errors: make(chan error, 1),
forcedMTU: 1500,
forcedMTU: forcedMTU,
}
err = tun.rings.Init()
tun.rings, err = wintun.NewRingDescriptor()
if err != nil {
tun.Close()
return nil, fmt.Errorf("Error creating events: %v", err)
}
tun.handle, err = tun.wt.Register(&tun.rings)
tun.handle, err = tun.wt.Register(tun.rings)
if err != nil {
tun.Close()
return nil, fmt.Errorf("Error registering rings: %v", err)
@@ -214,6 +221,9 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
tun.rate.update(uint64(packetSize))
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
tun.writeLock.Lock()
defer tun.writeLock.Unlock()
buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
if buffHead >= wintun.PacketCapacity {
return 0, os.ErrClosed

View File

@@ -40,7 +40,7 @@ func (bd *BoundaryDescriptor) AddSid(requiredSid *windows.SID) error {
return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid)
}
// PrivateNamespace represents a private namespace. Duh?!
// PrivateNamespace represents a private namespace.
type PrivateNamespace windows.Handle
// CreatePrivateNamespace creates a private namespace.

View File

@@ -6,6 +6,7 @@
package wintun
import (
"runtime"
"unsafe"
"golang.org/x/sys/windows"
@@ -53,25 +54,44 @@ func PacketAlign(size uint32) uint32 {
return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
}
func (descriptor *RingDescriptor) Init() (err error) {
func NewRingDescriptor() (descriptor *RingDescriptor, err error) {
descriptor = new(RingDescriptor)
allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
if err != nil {
return
}
defer func() {
if err != nil {
descriptor.free()
descriptor = nil
}
}()
descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
descriptor.Send.Ring = &Ring{}
descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion))
descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
return
}
descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
descriptor.Receive.Ring = &Ring{}
descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{})))
descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil {
windows.CloseHandle(descriptor.Send.TailMoved)
return
}
runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() })
return
}
func (descriptor *RingDescriptor) free() {
if descriptor.Send.Ring != nil {
windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE)
descriptor.Send.Ring = nil
descriptor.Receive.Ring = nil
}
}
func (descriptor *RingDescriptor) Close() {
if descriptor.Send.TailMoved != 0 {
windows.CloseHandle(descriptor.Send.TailMoved)

View File

@@ -57,7 +57,7 @@ type DevInfoData struct {
_ uintptr
}
// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass).
// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supersedes the functionality of SetupDiGetDeviceInfoListClass).
type DevInfoListDetailData struct {
size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const.
ClassGUID windows.GUID

View File

@@ -40,9 +40,9 @@ const (
)
// makeWintun creates a Wintun interface handle and populates it from the device's registry key.
func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData, pool Pool) (*Interface, error) {
func makeWintun(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData, pool Pool) (*Interface, error) {
// Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE)
key, err := devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE)
if err != nil {
return nil, fmt.Errorf("Device-specific registry key open failed: %v", err)
}
@@ -72,7 +72,7 @@ func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfo
return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err)
}
instanceID, err := deviceInfoSet.DeviceInstanceID(deviceInfoData)
instanceID, err := devInfo.DeviceInstanceID(devInfoData)
if err != nil {
return nil, fmt.Errorf("DeviceInstanceID failed: %v", err)
}
@@ -109,11 +109,11 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
}()
// Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil {
return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err)
}
defer devInfoList.Close()
defer devInfo.Close()
// Windows requires each interface to have a different name. When
// enforcing this, Windows treats interface names case-insensitive. If an
@@ -123,7 +123,7 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
ifname = strings.ToLower(ifname)
for index := 0; ; index++ {
deviceData, err := devInfoList.EnumDeviceInfo(index)
devInfoData, err := devInfo.EnumDeviceInfo(index)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
@@ -131,7 +131,16 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
continue
}
wintun, err := makeWintun(devInfoList, deviceData, pool)
// Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
if err != nil {
continue
}
if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
continue
}
wintun, err := makeWintun(devInfo, devInfoData, pool)
if err != nil {
continue
}
@@ -145,14 +154,14 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
ifname3 := removeNumberedSuffix(ifname2)
if ifname == ifname2 || ifname == ifname3 {
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
if err != nil {
return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
}
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
for index := 0; ; index++ {
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, index)
driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
@@ -161,13 +170,13 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
}
// Get driver info details.
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil {
continue
}
if driverDetailData.IsCompatible(hardwareID) {
isMember, err := pool.isMember(devInfoList, deviceData)
isMember, err := pool.isMember(devInfo, devInfoData)
if err != nil {
return nil, err
}
@@ -206,12 +215,12 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
}()
// Create an empty device info set for network adapter device class.
devInfoList, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
devInfo, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
if err != nil {
err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err)
return
}
defer devInfoList.Close()
defer devInfo.Close()
// Get the device class name from GUID.
className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "")
@@ -222,43 +231,43 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
// Create a new device info element and add it to the device info set.
deviceTypeName := pool.deviceTypeName()
deviceData, err := devInfoList.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID)
devInfoData, err := devInfo.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID)
if err != nil {
err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err)
return
}
err = setQuietInstall(devInfoList, deviceData)
err = setQuietInstall(devInfo, devInfoData)
if err != nil {
err = fmt.Errorf("Setting quiet installation failed: %v", err)
return
}
// Set a device information element as the selected member of a device information set.
err = devInfoList.SetSelectedDevice(deviceData)
err = devInfo.SetSelectedDevice(devInfoData)
if err != nil {
err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err)
return
}
// Set Plug&Play device hardware ID property.
err = devInfoList.SetDeviceRegistryPropertyString(deviceData, setupapi.SPDRP_HARDWAREID, hardwareID)
err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_HARDWAREID, hardwareID)
if err != nil {
err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err)
return
}
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms
err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms
if err != nil {
err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
return
}
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
driverDate := windows.Filetime{}
driverVersion := uint64(0)
for index := 0; ; index++ { // TODO: This loop takes ~600ms
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, index)
driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
@@ -268,13 +277,13 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
// Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match.
if driverData.IsNewer(driverDate, driverVersion) {
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil {
continue
}
if driverDetailData.IsCompatible(hardwareID) {
err := devInfoList.SetSelectedDriver(deviceData, driverData)
err := devInfo.SetSelectedDriver(devInfoData, driverData)
if err != nil {
continue
}
@@ -299,10 +308,10 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
}
// Set class installer parameters for DIF_REMOVE.
if devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil {
if devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil {
// Call appropriate class installer.
if devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) == nil {
rebootRequired = rebootRequired || checkReboot(devInfoList, deviceData)
if devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData) == nil {
rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
}
}
@@ -311,14 +320,14 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
}()
// Call appropriate class installer.
err = devInfoList.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, deviceData)
err = devInfo.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, devInfoData)
if err != nil {
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err)
return
}
// Register device co-installers if any. (Ignore errors)
devInfoList.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, deviceData)
devInfo.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, devInfoData)
var netDevRegKey registry.Key
const pollTimeout = time.Millisecond * 50
@@ -326,7 +335,7 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
if i != 0 {
time.Sleep(pollTimeout)
}
netDevRegKey, err = devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY)
netDevRegKey, err = devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY)
if err == nil {
break
}
@@ -345,17 +354,17 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
}
// Install interfaces if any. (Ignore errors)
devInfoList.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, deviceData)
devInfo.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, devInfoData)
// Install the device.
err = devInfoList.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, deviceData)
err = devInfo.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, devInfoData)
if err != nil {
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err)
return
}
rebootRequired = checkReboot(devInfoList, deviceData)
rebootRequired = checkReboot(devInfo, devInfoData)
err = devInfoList.SetDeviceRegistryPropertyString(deviceData, setupapi.SPDRP_DEVICEDESC, deviceTypeName)
err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_DEVICEDESC, deviceTypeName)
if err != nil {
err = fmt.Errorf("SetDeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
return
@@ -381,7 +390,7 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
}
// Get network interface.
wintun, err = makeWintun(devInfoList, deviceData, pool)
wintun, err = makeWintun(devInfo, devInfoData, pool)
if err != nil {
err = fmt.Errorf("makeWintun failed: %v", err)
return
@@ -435,14 +444,14 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
// if the interface was not found. It returns a bool indicating whether
// a reboot is required.
func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) {
devInfoList, deviceData, err := wintun.deviceData()
devInfo, devInfoData, err := wintun.devInfoData()
if err == windows.ERROR_OBJECT_NOT_FOUND {
return false, nil
}
if err != nil {
return false, err
}
defer devInfoList.Close()
defer devInfo.Close()
// Remove the device.
removeDeviceParams := setupapi.RemoveDeviceParams{
@@ -451,18 +460,18 @@ func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) {
}
// Set class installer parameters for DIF_REMOVE.
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
if err != nil {
return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err)
}
// Call appropriate class installer.
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData)
err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
if err != nil {
return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err)
}
return checkReboot(devInfoList, deviceData), nil
return checkReboot(devInfo, devInfoData), nil
}
// DeleteMatchingInterfaces deletes all Wintun interfaces, which match
@@ -479,14 +488,14 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
windows.CloseHandle(mutex)
}()
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil {
return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())}
}
defer devInfoList.Close()
defer devInfo.Close()
for i := 0; ; i++ {
deviceData, err := devInfoList.EnumDeviceInfo(i)
devInfoData, err := devInfo.EnumDeviceInfo(i)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
@@ -494,22 +503,31 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
continue
}
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
// Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
if err != nil {
continue
}
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
continue
}
err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
if err != nil {
continue
}
defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
isWintun := false
for j := 0; ; j++ {
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, j)
driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, j)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
}
continue
}
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil {
continue
}
@@ -522,7 +540,7 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
continue
}
isMember, err := pool.isMember(devInfoList, deviceData)
isMember, err := pool.isMember(devInfo, devInfoData)
if err != nil {
errors = append(errors, err)
continue
@@ -531,7 +549,7 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
continue
}
wintun, err := makeWintun(devInfoList, deviceData, pool)
wintun, err := makeWintun(devInfo, devInfoData, pool)
if err != nil {
errors = append(errors, fmt.Errorf("Unable to make Wintun interface object: %v", err))
continue
@@ -540,41 +558,41 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
continue
}
err = setQuietInstall(devInfoList, deviceData)
err = setQuietInstall(devInfo, devInfoData)
if err != nil {
errors = append(errors, err)
continue
}
inst := deviceData.DevInst
inst := devInfoData.DevInst
removeDeviceParams := setupapi.RemoveDeviceParams{
ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
}
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
if err != nil {
errors = append(errors, err)
continue
}
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData)
err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
if err != nil {
errors = append(errors, err)
continue
}
rebootRequired = rebootRequired || checkReboot(devInfoList, deviceData)
rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
deviceInstancesDeleted = append(deviceInstancesDeleted, inst)
}
return
}
// isMember checks if SPDRP_DEVICEDESC or SPDRP_FRIENDLYNAME match device type name.
func (pool Pool) isMember(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) (bool, error) {
deviceDescVal, err := deviceInfoSet.DeviceRegistryProperty(deviceInfoData, setupapi.SPDRP_DEVICEDESC)
func (pool Pool) isMember(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) (bool, error) {
deviceDescVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_DEVICEDESC)
if err != nil {
return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
}
deviceDesc, _ := deviceDescVal.(string)
friendlyNameVal, err := deviceInfoSet.DeviceRegistryProperty(deviceInfoData, setupapi.SPDRP_FRIENDLYNAME)
friendlyNameVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_FRIENDLYNAME)
if err != nil {
return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err)
}
@@ -585,8 +603,8 @@ func (pool Pool) isMember(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupa
}
// checkReboot checks device install parameters if a system reboot is required.
func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) bool {
devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData)
func checkReboot(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) bool {
devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
if err != nil {
return false
}
@@ -595,14 +613,14 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
}
// setQuietInstall sets device install parameters for a quiet installation
func setQuietInstall(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) error {
devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData)
func setQuietInstall(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) error {
devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
if err != nil {
return err
}
devInstallParams.Flags |= setupapi.DI_QUIETINSTALL
return deviceInfoSet.SetDeviceInstallParams(deviceInfoData, devInstallParams)
return devInfo.SetDeviceInstallParams(devInfoData, devInstallParams)
}
// deviceTypeName returns pool-specific device type name.
@@ -721,18 +739,18 @@ func (wintun *Interface) tcpipInterfaceRegKeyName() (path string, err error) {
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil
}
// deviceData returns TUN device info list handle and interface device info
// devInfoData returns TUN device info list handle and interface device info
// data. The device info list handle must be closed after use. In case the
// device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned.
func (wintun *Interface) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData, error) {
func (wintun *Interface) devInfoData() (setupapi.DevInfo, *setupapi.DevInfoData, error) {
// Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil {
return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())
}
for index := 0; ; index++ {
deviceData, err := devInfoList.EnumDeviceInfo(index)
devInfoData, err := devInfo.EnumDeviceInfo(index)
if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS {
break
@@ -742,22 +760,22 @@ func (wintun *Interface) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData,
// Get interface ID.
// TODO: Store some ID in the Wintun object such that this call isn't required.
wintun2, err := makeWintun(devInfoList, deviceData, wintun.pool)
wintun2, err := makeWintun(devInfo, devInfoData, wintun.pool)
if err != nil {
continue
}
if wintun.cfgInstanceID == wintun2.cfgInstanceID {
err = setQuietInstall(devInfoList, deviceData)
err = setQuietInstall(devInfo, devInfoData)
if err != nil {
devInfoList.Close()
devInfo.Close()
return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err)
}
return devInfoList, deviceData, nil
return devInfo, devInfoData, nil
}
}
devInfoList.Close()
devInfo.Close()
return 0, nil, windows.ERROR_OBJECT_NOT_FOUND
}