23 Commits

Author SHA1 Message Date
Jason A. Donenfeld
ae88e2a2cd version: bump snapshot 2020-03-20 12:00:53 -06:00
Jason A. Donenfeld
4739708ca4 noise: unify zero checking of ecdh 2020-03-17 23:07:14 -06:00
Tobias Klauser
b33219c2cf global: use RTMGRP_* consts from x/sys/unix
Update the golang.org/x/sys/unix dependency and use the newly introduced
RTMGRP_* consts instead of using the corresponding RTNLGRP_* const to
create a mask.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
2020-03-17 23:07:11 -06:00
Jason A. Donenfeld
9cbcff10dd send: account for zero mtu
Don't divide by zero.
2020-02-14 18:53:55 +01:00
Jason A. Donenfeld
6ed56ff2df device: fix private key removal logic 2020-02-04 22:02:53 +01:00
Jason A. Donenfeld
cb4bb63030 uapi: allow unsetting device private key with /dev/null 2020-02-04 22:02:53 +01:00
Jason A. Donenfeld
05b03c6750 version: bump snapshot 2020-01-21 16:27:19 +01:00
Jason A. Donenfeld
caebdfe9d0 tun: darwin: ignore ENOMEM errors
Coauthored-by: Andrej Mihajlov <and@mullvad.net>
2020-01-15 13:39:37 -05:00
Jason A. Donenfeld
4fa2ea6a2d tun: windows: serialize write calls 2020-01-07 11:40:45 -05:00
Jason A. Donenfeld
89dd065e53 README: update repo urls 2019-12-30 11:53:39 +01:00
Jason A. Donenfeld
ddfad453cf device: SendmsgN mutates the input sockaddr
So we take a new granular lock to prevent concurrent writes from
racing.

WARNING: DATA RACE
Write at 0x00c0011f2740 by goroutine 27:
  golang.org/x/sys/unix.(*SockaddrInet4).sockaddr()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:384
+0x114
  golang.org/x/sys/unix.SendmsgN()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:1304
+0x288
  golang.zx2c4.com/wireguard/device.send4()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:485
+0x11f
  golang.zx2c4.com/wireguard/device.(*nativeBind).Send()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:268
+0x1d6
  golang.zx2c4.com/wireguard/device.(*Peer).SendBuffer()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/peer.go:151
+0x285
  golang.zx2c4.com/wireguard/device.(*Peer).SendHandshakeInitiation()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:163
+0x692
  golang.zx2c4.com/wireguard/device.(*Device).RoutineReadFromTUN()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:318
+0x4b8

Previous write at 0x00c0011f2740 by goroutine 386:
  golang.org/x/sys/unix.(*SockaddrInet4).sockaddr()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:384
+0x114
  golang.org/x/sys/unix.SendmsgN()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:1304
+0x288
  golang.zx2c4.com/wireguard/device.send4()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:485
+0x11f
  golang.zx2c4.com/wireguard/device.(*nativeBind).Send()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:268
+0x1d6
  golang.zx2c4.com/wireguard/device.(*Peer).SendBuffer()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/peer.go:151
+0x285
  golang.zx2c4.com/wireguard/device.(*Peer).SendHandshakeInitiation()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:163
+0x692
  golang.zx2c4.com/wireguard/device.expiredRetransmitHandshake()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/timers.go:110
+0x40c
  golang.zx2c4.com/wireguard/device.(*Peer).NewTimer.func1()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/timers.go:42
+0xd8

Goroutine 27 (running) created at:
  golang.zx2c4.com/wireguard/device.NewDevice()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/device.go:322
+0x5e8
  main.main()
      /go/src/x/main.go:102 +0x58e

Goroutine 386 (finished) created at:
  time.goFunc()
      /usr/local/go/src/time/sleep.go:168 +0x51

Reported-by: Ben Burkert <ben@benburkert.com>
2019-11-28 11:11:13 +01:00
Jason A. Donenfeld
2b242f9393 wintun: manage ring memory manually
It's large and Go's garbage collector doesn't deal with it especially
well.
2019-11-22 13:13:55 +01:00
Jason A. Donenfeld
4cdf805b29 constants: recalculate rekey max based on a one minute flood
Discussed-with: Mathias Hall-Andersen <mathias@hall-andersen.dk>
2019-10-30 14:29:32 +01:00
Jonathan Tooker
f7d0edd2ec global: fix a few typos courtesy of codespell
Signed-off-by: Jonathan Tooker <jonathan.tooker@netprotect.com>
2019-10-22 11:51:25 +02:00
Jason A. Donenfeld
ffffbbcc8a device: allow blackholing sockets 2019-10-21 13:29:57 +02:00
Jason A. Donenfeld
47b02c618b device: remove dead error reporting code 2019-10-21 11:46:54 +02:00
Jason A. Donenfeld
fd23c66fcd namespaceapi: remove tasteless comment 2019-10-21 09:02:29 +02:00
Jason A. Donenfeld
ae492d1b35 device: recheck counters while holding write lock 2019-10-17 15:43:06 +02:00
Jason A. Donenfeld
95fbfccf60 wintun: normalize variable names for their types 2019-10-17 15:30:56 +02:00
Avery Pennarun
c85e4a410f wintun: quickly ignore non-Wintun devices
Some devices take ~2 seconds to enumerate on Windows if we try to get
their instance name.  The hardware id property, on the other hand,
is available right away.

Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
[zx2c4: inlined this to where it makes sense, reused setupapi const]
2019-10-17 15:19:20 +02:00
Avery Pennarun
1b6c8ddbe8 tun: match windows CreateTUN signature to the Linux variant
Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
[zx2c4: fix default value]
2019-10-17 15:19:20 +02:00
Avery Pennarun
0abb6b668c rwcancel: handle EINTR and EAGAIN in unixSelect()
On my Chromebook (Linux 4.19.44 in a VM) and on an AWS EC2
machine, select() was sometimes returning EINTR. This is
harmless and just means you should try again. So let's try
again.

This eliminates a problem where the tunnel fails to come up
correctly and the program needs to be restarted.

Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
2019-10-17 15:19:17 +02:00
David Crawshaw
540d01e54a device: test packets between two fake devices
Signed-off-by: David Crawshaw <crawshaw@tailscale.io>
2019-10-16 11:38:28 +02:00
25 changed files with 498 additions and 236 deletions

View File

@@ -18,7 +18,7 @@ To run wireguard-go without forking to the background, pass `-f` or `--foregroun
$ wireguard-go -f wg0 $ wireguard-go -f wg0
``` ```
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/WireGuard/about/src/tools/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands. When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
To run with more logging you may set the environment variable `LOG_LEVEL=debug`. To run with more logging you may set the environment variable `LOG_LEVEL=debug`.

View File

@@ -18,7 +18,7 @@ const (
sockoptIPV6_UNICAST_IF = 31 sockoptIPV6_UNICAST_IF = 31
) )
func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error { func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
bytes := make([]byte, 4) bytes := make([]byte, 4)
binary.BigEndian.PutUint32(bytes, interfaceIndex) binary.BigEndian.PutUint32(bytes, interfaceIndex)
@@ -41,10 +41,11 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
if err != nil { if err != nil {
return err return err
} }
device.net.bind.(*nativeBind).blackhole4 = blackhole
return nil return nil
} }
func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error { func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn() sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
if err != nil { if err != nil {
return err return err
@@ -58,5 +59,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
if err != nil { if err != nil {
return err return err
} }
device.net.bind.(*nativeBind).blackhole6 = blackhole
return nil return nil
} }

View File

@@ -23,6 +23,8 @@ import (
type nativeBind struct { type nativeBind struct {
ipv4 *net.UDPConn ipv4 *net.UDPConn
ipv6 *net.UDPConn ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
} }
type NativeEndpoint net.UDPAddr type NativeEndpoint net.UDPAddr
@@ -159,11 +161,17 @@ func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
if bind.ipv4 == nil { if bind.ipv4 == nil {
return syscall.EAFNOSUPPORT return syscall.EAFNOSUPPORT
} }
if bind.blackhole4 {
return nil
}
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
} else { } else {
if bind.ipv6 == nil { if bind.ipv6 == nil {
return syscall.EAFNOSUPPORT return syscall.EAFNOSUPPORT
} }
if bind.blackhole6 {
return nil
}
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
} }
return err return err

View File

@@ -7,7 +7,7 @@
* This implements userspace semantics of "sticky sockets", modeled after * This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port * WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code: * of the sticky-sockets.c example code:
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c * https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c
* *
* Currently there is no way to achieve this within the net package: * Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930 * See e.g. https://github.com/golang/go/issues/17930
@@ -43,6 +43,7 @@ type IPv6Source struct {
} }
type NativeEndpoint struct { type NativeEndpoint struct {
sync.Mutex
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
src [unsafe.Sizeof(IPv6Source{})]byte src [unsafe.Sizeof(IPv6Source{})]byte
isV6 bool isV6 bool
@@ -117,7 +118,7 @@ func createNetlinkRouteSocket() (int, error) {
} }
saddr := &unix.SockaddrNetlink{ saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK, Family: unix.AF_NETLINK,
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), Groups: unix.RTMGRP_IPV4_ROUTE,
} }
err = unix.Bind(sock, saddr) err = unix.Bind(sock, saddr)
if err != nil { if err != nil {
@@ -145,7 +146,7 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
go bind.routineRouteListener(device) go bind.routineRouteListener(device)
// attempt ipv6 bind, update port if succesful // attempt ipv6 bind, update port if successful
bind.sock6, newPort, err = create6(port) bind.sock6, newPort, err = create6(port)
if err != nil { if err != nil {
@@ -157,7 +158,7 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
port = newPort port = newPort
} }
// attempt ipv4 bind, update port if succesful // attempt ipv4 bind, update port if successful
bind.sock4, newPort, err = create4(port) bind.sock4, newPort, err = create4(port)
if err != nil { if err != nil {
@@ -482,7 +483,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
}, },
} }
end.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.Unlock()
if err == nil { if err == nil {
return nil return nil
@@ -493,7 +496,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
if err == unix.EINVAL { if err == unix.EINVAL {
end.ClearSrc() end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{} cmsg.pktinfo = unix.Inet4Pktinfo{}
end.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.Unlock()
} }
return err return err
@@ -522,7 +527,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
cmsg.pktinfo.Ifindex = 0 cmsg.pktinfo.Ifindex = 0
} }
end.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.Unlock()
if err == nil { if err == nil {
return nil return nil
@@ -533,7 +540,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
if err == unix.EINVAL { if err == unix.EINVAL {
end.ClearSrc() end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{} cmsg.pktinfo = unix.Inet6Pktinfo{}
end.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.Unlock()
} }
return err return err
@@ -541,7 +550,7 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header // construct message header
var cmsg struct { var cmsg struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr
@@ -573,7 +582,7 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header // construct message header
var cmsg struct { var cmsg struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr

View File

@@ -12,7 +12,7 @@ import (
/* Specification constants */ /* Specification constants */
const ( const (
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 RekeyAfterMessages = (1 << 60)
RejectAfterMessages = (1 << 64) - (1 << 4) - 1 RejectAfterMessages = (1 << 64) - (1 << 4) - 1
RekeyAfterTime = time.Second * 120 RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90 RekeyAttemptTime = time.Second * 90

View File

@@ -236,24 +236,12 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// do static-static DH pre-computations // do static-static DH pre-computations
rmKey := device.staticIdentity.privateKey.IsZero()
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for key, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
handshake := &peer.handshake handshake := &peer.handshake
if rmKey {
handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic) handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
}
if isZero(handshake.precomputedStaticStatic[:]) {
unsafeRemovePeer(device, peer, key)
} else {
expiredPeers = append(expiredPeers, peer) expiredPeers = append(expiredPeers, peer)
} }
}
for _, peer := range lockedPeers { for _, peer := range lockedPeers {
peer.handshake.mutex.RUnlock() peer.handshake.mutex.RUnlock()

View File

@@ -5,54 +5,212 @@
package device package device
/* Create two device instances and simulate full WireGuard interaction
* without network dependencies
*/
import ( import (
"bufio"
"bytes" "bytes"
"encoding/binary"
"io"
"net"
"os"
"strings"
"testing" "testing"
"time"
"golang.zx2c4.com/wireguard/tun"
) )
func TestDevice(t *testing.T) { func TestTwoDevicePing(t *testing.T) {
// TODO(crawshaw): pick unused ports on localhost
// prepare tun devices for generating traffic cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58
listen_port=53511
tun1 := newDummyTUN("tun1") replace_peers=true
tun2 := newDummyTUN("tun2") public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725
protocol_version=1
_ = tun1 replace_allowed_ips=true
_ = tun2 allowed_ip=1.0.0.2/32
endpoint=127.0.0.1:53512`
// prepare endpoints tun1 := NewChannelTUN()
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
end1, err := CreateDummyEndpoint() dev1.Up()
if err != nil { defer dev1.Close()
t.Error("failed to create endpoint:", err.Error()) if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil {
}
end2, err := CreateDummyEndpoint()
if err != nil {
t.Error("failed to create endpoint:", err.Error())
}
_ = end1
_ = end2
// create binds
}
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tun := newDummyTUN("dummy")
logger := NewLogger(LogLevelError, "") cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768
device := NewDevice(tun, logger) listen_port=53512
device.SetPrivateKey(sk) replace_peers=true
return device public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427
protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.1/32
endpoint=127.0.0.1:53511`
tun2 := NewChannelTUN()
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
dev2.Up()
defer dev2.Close()
if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil {
t.Fatal(err)
}
t.Run("ping 1.0.0.1", func(t *testing.T) {
msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
tun2.Outbound <- msg2to1
select {
case msgRecv := <-tun1.Inbound:
if !bytes.Equal(msg2to1, msgRecv) {
t.Error("ping did not transit correctly")
}
case <-time.After(300 * time.Millisecond):
t.Error("ping did not transit")
}
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
tun1.Outbound <- msg1to2
select {
case msgRecv := <-tun2.Inbound:
if !bytes.Equal(msg1to2, msgRecv) {
t.Error("return ping did not transit correctly")
}
case <-time.After(300 * time.Millisecond):
t.Error("return ping did not transit")
}
})
}
func ping(dst, src net.IP) []byte {
localPort := uint16(1337)
seq := uint16(0)
payload := make([]byte, 4)
binary.BigEndian.PutUint16(payload[0:], localPort)
binary.BigEndian.PutUint16(payload[2:], seq)
return genICMPv4(payload, dst, src)
}
// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
func checksum(buf []byte, initial uint16) uint16 {
v := uint32(initial)
for i := 0; i < len(buf)-1; i += 2 {
v += uint32(binary.BigEndian.Uint16(buf[i:]))
}
if len(buf)%2 == 1 {
v += uint32(buf[len(buf)-1]) << 8
}
for v > 0xffff {
v = (v >> 16) + (v & 0xffff)
}
return ^uint16(v)
}
func genICMPv4(payload []byte, dst, src net.IP) []byte {
const (
icmpv4ProtocolNumber = 1
icmpv4Echo = 8
icmpv4ChecksumOffset = 2
icmpv4Size = 8
ipv4Size = 20
ipv4TotalLenOffset = 2
ipv4ChecksumOffset = 10
ttl = 65
)
hdr := make([]byte, ipv4Size+icmpv4Size)
ip := hdr[0:ipv4Size]
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
// https://tools.ietf.org/html/rfc792
icmpv4[0] = icmpv4Echo // type
icmpv4[1] = 0 // code
chksum := ^checksum(icmpv4, checksum(payload, 0))
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
// https://tools.ietf.org/html/rfc760 section 3.1
length := uint16(len(hdr) + len(payload))
ip[0] = (4 << 4) | (ipv4Size / 4)
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl
ip[9] = icmpv4ProtocolNumber
copy(ip[12:], src.To4())
copy(ip[16:], dst.To4())
chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
var v []byte
v = append(v, hdr...)
v = append(v, payload...)
return []byte(v)
}
// TODO(crawshaw): find a reusable home for this. package devicetest?
type ChannelTUN struct {
Inbound chan []byte // incoming packets, closed on TUN close
Outbound chan []byte // outbound packets, blocks forever on TUN close
closed chan struct{}
events chan tun.Event
tun chTun
}
func NewChannelTUN() *ChannelTUN {
c := &ChannelTUN{
Inbound: make(chan []byte),
Outbound: make(chan []byte),
closed: make(chan struct{}),
events: make(chan tun.Event, 1),
}
c.tun.c = c
c.events <- tun.EventUp
return c
}
func (c *ChannelTUN) TUN() tun.Device {
return &c.tun
}
type chTun struct {
c *ChannelTUN
}
func (t *chTun) File() *os.File { return nil }
func (t *chTun) Read(data []byte, offset int) (int, error) {
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case msg := <-t.c.Outbound:
return copy(data[offset:], msg), nil
}
}
// Write is called by the wireguard device to deliver a packet for routing.
func (t *chTun) Write(data []byte, offset int) (int, error) {
if offset == -1 {
close(t.c.closed)
close(t.c.events)
return 0, io.EOF
}
msg := make([]byte, len(data)-offset)
copy(msg, data[offset:])
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case t.c.Inbound <- msg:
return len(data) - offset, nil
}
}
func (t *chTun) Flush() error { return nil }
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
func (t *chTun) Events() chan tun.Event { return t.c.events }
func (t *chTun) Close() error {
t.Write(nil, -1)
return nil
} }
func assertNil(t *testing.T, err error) { func assertNil(t *testing.T, err error) {
@@ -66,3 +224,15 @@ func assertEqual(t *testing.T, a, b []byte) {
t.Fatal(a, "!=", b) t.Fatal(a, "!=", b)
} }
} }
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
tun := newDummyTUN("dummy")
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
}

View File

@@ -39,13 +39,13 @@ const (
) )
const ( const (
MessageInitiationSize = 148 // size of handshake initation message MessageInitiationSize = 148 // size of handshake initiation message
MessageResponseSize = 92 // size of response message MessageResponseSize = 92 // size of response message
MessageCookieReplySize = 64 // size of cookie reply message MessageCookieReplySize = 64 // size of cookie reply message
MessageTransportHeaderSize = 16 // size of data preceeding content in transport message MessageTransportHeaderSize = 16 // size of data preceding content in transport message
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
MessageKeepaliveSize = MessageTransportSize // size of keepalive MessageKeepaliveSize = MessageTransportSize // size of keepalive
MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
) )
const ( const (
@@ -154,6 +154,7 @@ func init() {
} }
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
var errZeroECDHResult = errors.New("ECDH returned all zeros")
device.staticIdentity.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock() defer device.staticIdentity.RUnlock()
@@ -162,12 +163,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errors.New("static shared secret is zero")
}
// create ephemeral key // create ephemeral key
var err error var err error
handshake.hash = InitialHash handshake.hash = InitialHash
handshake.chainKey = InitialChainKey handshake.chainKey = InitialChainKey
@@ -176,31 +172,22 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
return nil, err return nil, err
} }
// assign index
device.indexTable.Delete(handshake.localIndex)
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
handshake.mixHash(handshake.remoteStatic[:]) handshake.mixHash(handshake.remoteStatic[:])
msg := MessageInitiation{ msg := MessageInitiation{
Type: MessageInitiationType, Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(), Ephemeral: handshake.localEphemeral.publicKey(),
Sender: handshake.localIndex,
} }
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
// encrypt static key // encrypt static key
func() {
var key [chacha20poly1305.KeySize]byte
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
if isZero(ss[:]) {
return nil, errZeroECDHResult
}
var key [chacha20poly1305.KeySize]byte
KDF2( KDF2(
&handshake.chainKey, &handshake.chainKey,
&key, &key,
@@ -209,23 +196,29 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
) )
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
}()
handshake.mixHash(msg.Static[:]) handshake.mixHash(msg.Static[:])
// encrypt timestamp // encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) {
timestamp := tai64n.Now() return nil, errZeroECDHResult
func() { }
var key [chacha20poly1305.KeySize]byte
KDF2( KDF2(
&handshake.chainKey, &handshake.chainKey,
&key, &key,
handshake.chainKey[:], handshake.chainKey[:],
handshake.precomputedStaticStatic[:], handshake.precomputedStaticStatic[:],
) )
aead, _ := chacha20poly1305.New(key[:]) timestamp := tai64n.Now()
aead, _ = chacha20poly1305.New(key[:])
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
}()
// assign index
device.indexTable.Delete(handshake.localIndex)
msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
handshake.localIndex = msg.Sender
handshake.mixHash(msg.Timestamp[:]) handshake.mixHash(msg.Timestamp[:])
handshake.state = HandshakeInitiationCreated handshake.state = HandshakeInitiationCreated
@@ -250,16 +243,16 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key // decrypt static key
var err error var err error
var peerPK NoisePublicKey var peerPK NoisePublicKey
func() {
var key [chacha20poly1305.KeySize]byte var key [chacha20poly1305.KeySize]byte
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
if isZero(ss[:]) {
return nil
}
KDF2(&chainKey, &key, chainKey[:], ss[:]) KDF2(&chainKey, &key, chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:]) aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) _, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
}()
if err != nil { if err != nil {
return nil return nil
} }
@@ -273,23 +266,24 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
} }
handshake := &peer.handshake handshake := &peer.handshake
if isZero(handshake.precomputedStaticStatic[:]) {
return nil
}
// verify identity // verify identity
var timestamp tai64n.Timestamp var timestamp tai64n.Timestamp
var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock() handshake.mutex.RLock()
if isZero(handshake.precomputedStaticStatic[:]) {
handshake.mutex.RUnlock()
return nil
}
KDF2( KDF2(
&chainKey, &chainKey,
&key, &key,
chainKey[:], chainKey[:],
handshake.precomputedStaticStatic[:], handshake.precomputedStaticStatic[:],
) )
aead, _ := chacha20poly1305.New(key[:]) aead, _ = chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
if err != nil { if err != nil {
handshake.mutex.RUnlock() handshake.mutex.RUnlock()
@@ -315,8 +309,13 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
handshake.chainKey = chainKey handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral handshake.remoteEphemeral = msg.Ephemeral
if timestamp.After(handshake.lastTimestamp) {
handshake.lastTimestamp = timestamp handshake.lastTimestamp = timestamp
handshake.lastInitiationConsumption = time.Now() }
now := time.Now()
if now.After(handshake.lastInitiationConsumption) {
handshake.lastInitiationConsumption = now
}
handshake.state = HandshakeInitiationConsumed handshake.state = HandshakeInitiationConsumed
handshake.mutex.Unlock() handshake.mutex.Unlock()

View File

@@ -52,6 +52,15 @@ func (key *NoisePrivateKey) FromHex(src string) (err error) {
return return
} }
func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
err = loadExactHex(key[:], src)
if key.IsZero() {
return
}
key.clamp()
return
}
func (key NoisePrivateKey) ToHex() string { func (key NoisePrivateKey) ToHex() string {
return hex.EncodeToString(key[:]) return hex.EncodeToString(key[:])
} }

View File

@@ -108,7 +108,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
ssIsZero := isZero(handshake.precomputedStaticStatic[:])
handshake.remoteStatic = pk handshake.remoteStatic = pk
handshake.mutex.Unlock() handshake.mutex.Unlock()
@@ -116,13 +115,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.endpoint = nil peer.endpoint = nil
// conditionally add // add
if !ssIsZero {
device.peers.keyMap[pk] = peer device.peers.keyMap[pk] = peer
} else {
return nil, nil
}
// start peer // start peer

View File

@@ -220,10 +220,7 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
writer := bytes.NewBuffer(buff[:0]) writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, reply) binary.Write(writer, binary.LittleEndian, reply)
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
if err != nil { return nil
device.log.Error.Println("Failed to send cookie reply:", err)
}
return err
} }
func (peer *Peer) keepKeyFreshSending() { func (peer *Peer) keepKeyFreshSending() {
@@ -518,11 +515,19 @@ func (device *Device) RoutineEncryption() {
// pad content to multiple of 16 // pad content to multiple of 16
mtu := int(atomic.LoadInt32(&device.tun.mtu)) mtu := int(atomic.LoadInt32(&device.tun.mtu))
lastUnit := len(elem.packet) % mtu var paddedSize int
if mtu == 0 {
paddedSize = (len(elem.packet) + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
} else {
lastUnit := len(elem.packet)
if lastUnit > mtu {
lastUnit %= mtu
}
paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1) paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
if paddedSize > mtu { if paddedSize > mtu {
paddedSize = mtu paddedSize = mtu
} }
}
for i := len(elem.packet); i < paddedSize; i++ { for i := len(elem.packet); i < paddedSize; i++ {
elem.packet = append(elem.packet, 0) elem.packet = append(elem.packet, 0)
} }

View File

@@ -138,7 +138,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
switch key { switch key {
case "private_key": case "private_key":
var sk NoisePrivateKey var sk NoisePrivateKey
err := sk.FromHex(value) err := sk.FromMaybeZeroHex(value)
if err != nil { if err != nil {
logError.Println("Failed to set private_key:", err) logError.Println("Failed to set private_key:", err)
return &IPCError{ipc.IpcErrorInvalid} return &IPCError{ipc.IpcErrorInvalid}

View File

@@ -1,3 +1,3 @@
package device package device
const WireGuardGoVersion = "0.0.20191012" const WireGuardGoVersion = "0.0.20200320"

2
go.mod
View File

@@ -5,6 +5,6 @@ go 1.12
require ( require (
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc
golang.org/x/net v0.0.0-20191003171128-d98b1b443823 golang.org/x/net v0.0.0-20191003171128-d98b1b443823
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527
golang.org/x/text v0.3.2 golang.org/x/text v0.3.2
) )

4
go.sum
View File

@@ -6,8 +6,8 @@ golang.org/x/net v0.0.0-20191003171128-d98b1b443823 h1:Ypyv6BNJh07T1pUSrehkLemqP
golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c h1:6Zx7DRlKXf79yfxuQ/7GqV3w2y7aDsk6bGg0MzF5RVU= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=

View File

@@ -37,7 +37,7 @@ func main() {
logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion) logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
logger.Debug.Println("Debug log enabled") logger.Debug.Println("Debug log enabled")
tun, err := tun.CreateTUN(interfaceName) tun, err := tun.CreateTUN(interfaceName, 0)
if err == nil { if err == nil {
realInterfaceName, err2 := tun.Name() realInterfaceName, err2 := tun.Name()
if err2 == nil { if err2 == nil {

View File

@@ -36,7 +36,7 @@ func TestRatelimiter(t *testing.T) {
for i := 0; i < packetsBurstable; i++ { for i := 0; i < packetsBurstable; i++ {
Add(RatelimiterResult{ Add(RatelimiterResult{
allowed: true, allowed: true,
text: "inital burst", text: "initial burst",
}) })
} }

View File

@@ -60,7 +60,13 @@ func (rw *RWCancel) ReadyRead() bool {
fdset := fdSet{} fdset := fdSet{}
fdset.set(rw.fd) fdset.set(rw.fd)
fdset.set(closeFd) fdset.set(closeFd)
err := unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil) var err error
for {
err = unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
if err == nil || !RetryAfterError(err) {
break
}
}
if err != nil { if err != nil {
return false return false
} }
@@ -75,7 +81,13 @@ func (rw *RWCancel) ReadyWrite() bool {
fdset := fdSet{} fdset := fdSet{}
fdset.set(rw.fd) fdset.set(rw.fd)
fdset.set(closeFd) fdset.set(closeFd)
err := unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil) var err error
for {
err = unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
if err == nil || !RetryAfterError(err) {
break
}
}
if err != nil { if err != nil {
return false return false
} }

View File

@@ -11,6 +11,7 @@ import (
"net" "net"
"os" "os"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
@@ -42,6 +43,22 @@ type NativeTun struct {
var sockaddrCtlSize uintptr = 32 var sockaddrCtlSize uintptr = 32
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
for i := 0; i < 20; i++ {
iface, err = net.InterfaceByIndex(index)
if err != nil {
if opErr, ok := err.(*net.OpError); ok {
if syscallErr, ok := opErr.Err.(*os.SyscallError); ok && syscallErr.Err == syscall.ENOMEM {
time.Sleep(time.Duration(i) * time.Second / 3)
continue
}
}
}
return iface, err
}
return nil, err
}
func (tun *NativeTun) routineRouteListener(tunIfindex int) { func (tun *NativeTun) routineRouteListener(tunIfindex int) {
var ( var (
statusUp bool statusUp bool
@@ -74,7 +91,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
continue continue
} }
iface, err := net.InterfaceByIndex(ifindex) iface, err := retryInterfaceByIndex(ifindex)
if err != nil { if err != nil {
tun.errors <- err tun.errors <- err
return return

View File

@@ -35,7 +35,7 @@ type NativeTun struct {
name string // name of interface name string // name of interface
errors chan error // async error handling errors chan error // async error handling
events chan Event // device related events events chan Event // device related events
nopi bool // the device was pased IFF_NO_PI nopi bool // the device was passed IFF_NO_PI
netlinkSock int netlinkSock int
netlinkCancel *rwcancel.RWCancel netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex hackListenerClosed sync.Mutex
@@ -85,7 +85,7 @@ func createNetlinkSocket() (int, error) {
} }
saddr := &unix.SockaddrNetlink{ saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK, Family: unix.AF_NETLINK,
Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))), Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
} }
err = unix.Bind(sock, saddr) err = unix.Bind(sock, saddr)
if err != nil { if err != nil {

View File

@@ -9,6 +9,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe" "unsafe"
@@ -35,11 +36,12 @@ type NativeTun struct {
wt *wintun.Interface wt *wintun.Interface
handle windows.Handle handle windows.Handle
close bool close bool
rings wintun.RingDescriptor
events chan Event events chan Event
errors chan error errors chan error
forcedMTU int forcedMTU int
rate rateJuggler rate rateJuggler
rings *wintun.RingDescriptor
writeLock sync.Mutex
} }
const WintunPool = wintun.Pool("WireGuard") const WintunPool = wintun.Pool("WireGuard")
@@ -54,15 +56,15 @@ func nanotime() int64
// CreateTUN creates a Wintun interface with the given name. Should a Wintun // CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused. // interface with the same name exist, it is reused.
// //
func CreateTUN(ifname string) (Device, error) { func CreateTUN(ifname string, mtu int) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, nil) return CreateTUNWithRequestedGUID(ifname, nil, mtu)
} }
// //
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused. // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
// //
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Device, error) { func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
var err error var err error
var wt *wintun.Interface var wt *wintun.Interface
@@ -80,21 +82,26 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev
return nil, fmt.Errorf("Error creating interface: %v", err) return nil, fmt.Errorf("Error creating interface: %v", err)
} }
forcedMTU := 1420
if mtu > 0 {
forcedMTU = mtu
}
tun := &NativeTun{ tun := &NativeTun{
wt: wt, wt: wt,
handle: windows.InvalidHandle, handle: windows.InvalidHandle,
events: make(chan Event, 10), events: make(chan Event, 10),
errors: make(chan error, 1), errors: make(chan error, 1),
forcedMTU: 1500, forcedMTU: forcedMTU,
} }
err = tun.rings.Init() tun.rings, err = wintun.NewRingDescriptor()
if err != nil { if err != nil {
tun.Close() tun.Close()
return nil, fmt.Errorf("Error creating events: %v", err) return nil, fmt.Errorf("Error creating events: %v", err)
} }
tun.handle, err = tun.wt.Register(&tun.rings) tun.handle, err = tun.wt.Register(tun.rings)
if err != nil { if err != nil {
tun.Close() tun.Close()
return nil, fmt.Errorf("Error registering rings: %v", err) return nil, fmt.Errorf("Error registering rings: %v", err)
@@ -214,6 +221,9 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
tun.rate.update(uint64(packetSize)) tun.rate.update(uint64(packetSize))
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize) alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
tun.writeLock.Lock()
defer tun.writeLock.Unlock()
buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head) buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
if buffHead >= wintun.PacketCapacity { if buffHead >= wintun.PacketCapacity {
return 0, os.ErrClosed return 0, os.ErrClosed

View File

@@ -40,7 +40,7 @@ func (bd *BoundaryDescriptor) AddSid(requiredSid *windows.SID) error {
return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid) return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid)
} }
// PrivateNamespace represents a private namespace. Duh?! // PrivateNamespace represents a private namespace.
type PrivateNamespace windows.Handle type PrivateNamespace windows.Handle
// CreatePrivateNamespace creates a private namespace. // CreatePrivateNamespace creates a private namespace.

View File

@@ -6,6 +6,7 @@
package wintun package wintun
import ( import (
"runtime"
"unsafe" "unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@@ -53,25 +54,44 @@ func PacketAlign(size uint32) uint32 {
return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1) return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
} }
func (descriptor *RingDescriptor) Init() (err error) { func NewRingDescriptor() (descriptor *RingDescriptor, err error) {
descriptor = new(RingDescriptor)
allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
if err != nil {
return
}
defer func() {
if err != nil {
descriptor.free()
descriptor = nil
}
}()
descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{})) descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
descriptor.Send.Ring = &Ring{} descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion))
descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil { if err != nil {
return return
} }
descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{})) descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
descriptor.Receive.Ring = &Ring{} descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{})))
descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil { if err != nil {
windows.CloseHandle(descriptor.Send.TailMoved) windows.CloseHandle(descriptor.Send.TailMoved)
return return
} }
runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() })
return return
} }
func (descriptor *RingDescriptor) free() {
if descriptor.Send.Ring != nil {
windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE)
descriptor.Send.Ring = nil
descriptor.Receive.Ring = nil
}
}
func (descriptor *RingDescriptor) Close() { func (descriptor *RingDescriptor) Close() {
if descriptor.Send.TailMoved != 0 { if descriptor.Send.TailMoved != 0 {
windows.CloseHandle(descriptor.Send.TailMoved) windows.CloseHandle(descriptor.Send.TailMoved)

View File

@@ -57,7 +57,7 @@ type DevInfoData struct {
_ uintptr _ uintptr
} }
// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass). // DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supersedes the functionality of SetupDiGetDeviceInfoListClass).
type DevInfoListDetailData struct { type DevInfoListDetailData struct {
size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const. size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const.
ClassGUID windows.GUID ClassGUID windows.GUID

View File

@@ -40,9 +40,9 @@ const (
) )
// makeWintun creates a Wintun interface handle and populates it from the device's registry key. // makeWintun creates a Wintun interface handle and populates it from the device's registry key.
func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData, pool Pool) (*Interface, error) { func makeWintun(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData, pool Pool) (*Interface, error) {
// Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key. // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE) key, err := devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE)
if err != nil { if err != nil {
return nil, fmt.Errorf("Device-specific registry key open failed: %v", err) return nil, fmt.Errorf("Device-specific registry key open failed: %v", err)
} }
@@ -72,7 +72,7 @@ func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfo
return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err) return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err)
} }
instanceID, err := deviceInfoSet.DeviceInstanceID(deviceInfoData) instanceID, err := devInfo.DeviceInstanceID(devInfoData)
if err != nil { if err != nil {
return nil, fmt.Errorf("DeviceInstanceID failed: %v", err) return nil, fmt.Errorf("DeviceInstanceID failed: %v", err)
} }
@@ -109,11 +109,11 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
}() }()
// Create a list of network devices. // Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil { if err != nil {
return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err) return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err)
} }
defer devInfoList.Close() defer devInfo.Close()
// Windows requires each interface to have a different name. When // Windows requires each interface to have a different name. When
// enforcing this, Windows treats interface names case-insensitive. If an // enforcing this, Windows treats interface names case-insensitive. If an
@@ -123,7 +123,7 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
ifname = strings.ToLower(ifname) ifname = strings.ToLower(ifname)
for index := 0; ; index++ { for index := 0; ; index++ {
deviceData, err := devInfoList.EnumDeviceInfo(index) devInfoData, err := devInfo.EnumDeviceInfo(index)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -131,7 +131,16 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
continue continue
} }
wintun, err := makeWintun(devInfoList, deviceData, pool) // Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
if err != nil {
continue
}
if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
continue
}
wintun, err := makeWintun(devInfo, devInfoData, pool)
if err != nil { if err != nil {
continue continue
} }
@@ -145,14 +154,14 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
ifname3 := removeNumberedSuffix(ifname2) ifname3 := removeNumberedSuffix(ifname2)
if ifname == ifname2 || ifname == ifname3 { if ifname == ifname2 || ifname == ifname3 {
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
if err != nil { if err != nil {
return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err) return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
} }
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
for index := 0; ; index++ { for index := 0; ; index++ {
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, index) driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -161,13 +170,13 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
} }
// Get driver info details. // Get driver info details.
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData) driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil { if err != nil {
continue continue
} }
if driverDetailData.IsCompatible(hardwareID) { if driverDetailData.IsCompatible(hardwareID) {
isMember, err := pool.isMember(devInfoList, deviceData) isMember, err := pool.isMember(devInfo, devInfoData)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -206,12 +215,12 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
}() }()
// Create an empty device info set for network adapter device class. // Create an empty device info set for network adapter device class.
devInfoList, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "") devInfo, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err) err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err)
return return
} }
defer devInfoList.Close() defer devInfo.Close()
// Get the device class name from GUID. // Get the device class name from GUID.
className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "") className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "")
@@ -222,43 +231,43 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
// Create a new device info element and add it to the device info set. // Create a new device info element and add it to the device info set.
deviceTypeName := pool.deviceTypeName() deviceTypeName := pool.deviceTypeName()
deviceData, err := devInfoList.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID) devInfoData, err := devInfo.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID)
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err) err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err)
return return
} }
err = setQuietInstall(devInfoList, deviceData) err = setQuietInstall(devInfo, devInfoData)
if err != nil { if err != nil {
err = fmt.Errorf("Setting quiet installation failed: %v", err) err = fmt.Errorf("Setting quiet installation failed: %v", err)
return return
} }
// Set a device information element as the selected member of a device information set. // Set a device information element as the selected member of a device information set.
err = devInfoList.SetSelectedDevice(deviceData) err = devInfo.SetSelectedDevice(devInfoData)
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err) err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err)
return return
} }
// Set Plug&Play device hardware ID property. // Set Plug&Play device hardware ID property.
err = devInfoList.SetDeviceRegistryPropertyString(deviceData, setupapi.SPDRP_HARDWAREID, hardwareID) err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_HARDWAREID, hardwareID)
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err) err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err)
return return
} }
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err) err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
return return
} }
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
driverDate := windows.Filetime{} driverDate := windows.Filetime{}
driverVersion := uint64(0) driverVersion := uint64(0)
for index := 0; ; index++ { // TODO: This loop takes ~600ms for index := 0; ; index++ { // TODO: This loop takes ~600ms
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, index) driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -268,13 +277,13 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
// Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match. // Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match.
if driverData.IsNewer(driverDate, driverVersion) { if driverData.IsNewer(driverDate, driverVersion) {
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData) driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil { if err != nil {
continue continue
} }
if driverDetailData.IsCompatible(hardwareID) { if driverDetailData.IsCompatible(hardwareID) {
err := devInfoList.SetSelectedDriver(deviceData, driverData) err := devInfo.SetSelectedDriver(devInfoData, driverData)
if err != nil { if err != nil {
continue continue
} }
@@ -299,10 +308,10 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
} }
// Set class installer parameters for DIF_REMOVE. // Set class installer parameters for DIF_REMOVE.
if devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil { if devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil {
// Call appropriate class installer. // Call appropriate class installer.
if devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) == nil { if devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData) == nil {
rebootRequired = rebootRequired || checkReboot(devInfoList, deviceData) rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
} }
} }
@@ -311,14 +320,14 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
}() }()
// Call appropriate class installer. // Call appropriate class installer.
err = devInfoList.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, deviceData) err = devInfo.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, devInfoData)
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err) err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err)
return return
} }
// Register device co-installers if any. (Ignore errors) // Register device co-installers if any. (Ignore errors)
devInfoList.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, deviceData) devInfo.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, devInfoData)
var netDevRegKey registry.Key var netDevRegKey registry.Key
const pollTimeout = time.Millisecond * 50 const pollTimeout = time.Millisecond * 50
@@ -326,7 +335,7 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
if i != 0 { if i != 0 {
time.Sleep(pollTimeout) time.Sleep(pollTimeout)
} }
netDevRegKey, err = devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY) netDevRegKey, err = devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY)
if err == nil { if err == nil {
break break
} }
@@ -345,17 +354,17 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
} }
// Install interfaces if any. (Ignore errors) // Install interfaces if any. (Ignore errors)
devInfoList.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, deviceData) devInfo.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, devInfoData)
// Install the device. // Install the device.
err = devInfoList.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, deviceData) err = devInfo.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, devInfoData)
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err) err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err)
return return
} }
rebootRequired = checkReboot(devInfoList, deviceData) rebootRequired = checkReboot(devInfo, devInfoData)
err = devInfoList.SetDeviceRegistryPropertyString(deviceData, setupapi.SPDRP_DEVICEDESC, deviceTypeName) err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_DEVICEDESC, deviceTypeName)
if err != nil { if err != nil {
err = fmt.Errorf("SetDeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err) err = fmt.Errorf("SetDeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
return return
@@ -381,7 +390,7 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
} }
// Get network interface. // Get network interface.
wintun, err = makeWintun(devInfoList, deviceData, pool) wintun, err = makeWintun(devInfo, devInfoData, pool)
if err != nil { if err != nil {
err = fmt.Errorf("makeWintun failed: %v", err) err = fmt.Errorf("makeWintun failed: %v", err)
return return
@@ -435,14 +444,14 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
// if the interface was not found. It returns a bool indicating whether // if the interface was not found. It returns a bool indicating whether
// a reboot is required. // a reboot is required.
func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) { func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) {
devInfoList, deviceData, err := wintun.deviceData() devInfo, devInfoData, err := wintun.devInfoData()
if err == windows.ERROR_OBJECT_NOT_FOUND { if err == windows.ERROR_OBJECT_NOT_FOUND {
return false, nil return false, nil
} }
if err != nil { if err != nil {
return false, err return false, err
} }
defer devInfoList.Close() defer devInfo.Close()
// Remove the device. // Remove the device.
removeDeviceParams := setupapi.RemoveDeviceParams{ removeDeviceParams := setupapi.RemoveDeviceParams{
@@ -451,18 +460,18 @@ func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) {
} }
// Set class installer parameters for DIF_REMOVE. // Set class installer parameters for DIF_REMOVE.
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
if err != nil { if err != nil {
return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err) return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err)
} }
// Call appropriate class installer. // Call appropriate class installer.
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
if err != nil { if err != nil {
return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err) return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err)
} }
return checkReboot(devInfoList, deviceData), nil return checkReboot(devInfo, devInfoData), nil
} }
// DeleteMatchingInterfaces deletes all Wintun interfaces, which match // DeleteMatchingInterfaces deletes all Wintun interfaces, which match
@@ -479,14 +488,14 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
windows.CloseHandle(mutex) windows.CloseHandle(mutex)
}() }()
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil { if err != nil {
return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())} return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())}
} }
defer devInfoList.Close() defer devInfo.Close()
for i := 0; ; i++ { for i := 0; ; i++ {
deviceData, err := devInfoList.EnumDeviceInfo(i) devInfoData, err := devInfo.EnumDeviceInfo(i)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -494,22 +503,31 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
continue continue
} }
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) // Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
if err != nil { if err != nil {
continue continue
} }
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
continue
}
err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
if err != nil {
continue
}
defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
isWintun := false isWintun := false
for j := 0; ; j++ { for j := 0; ; j++ {
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, j) driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, j)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
} }
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData) driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil { if err != nil {
continue continue
} }
@@ -522,7 +540,7 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
continue continue
} }
isMember, err := pool.isMember(devInfoList, deviceData) isMember, err := pool.isMember(devInfo, devInfoData)
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
continue continue
@@ -531,7 +549,7 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
continue continue
} }
wintun, err := makeWintun(devInfoList, deviceData, pool) wintun, err := makeWintun(devInfo, devInfoData, pool)
if err != nil { if err != nil {
errors = append(errors, fmt.Errorf("Unable to make Wintun interface object: %v", err)) errors = append(errors, fmt.Errorf("Unable to make Wintun interface object: %v", err))
continue continue
@@ -540,41 +558,41 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
continue continue
} }
err = setQuietInstall(devInfoList, deviceData) err = setQuietInstall(devInfo, devInfoData)
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
continue continue
} }
inst := deviceData.DevInst inst := devInfoData.DevInst
removeDeviceParams := setupapi.RemoveDeviceParams{ removeDeviceParams := setupapi.RemoveDeviceParams{
ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE), ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
Scope: setupapi.DI_REMOVEDEVICE_GLOBAL, Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
} }
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
continue continue
} }
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
continue continue
} }
rebootRequired = rebootRequired || checkReboot(devInfoList, deviceData) rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
deviceInstancesDeleted = append(deviceInstancesDeleted, inst) deviceInstancesDeleted = append(deviceInstancesDeleted, inst)
} }
return return
} }
// isMember checks if SPDRP_DEVICEDESC or SPDRP_FRIENDLYNAME match device type name. // isMember checks if SPDRP_DEVICEDESC or SPDRP_FRIENDLYNAME match device type name.
func (pool Pool) isMember(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) (bool, error) { func (pool Pool) isMember(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) (bool, error) {
deviceDescVal, err := deviceInfoSet.DeviceRegistryProperty(deviceInfoData, setupapi.SPDRP_DEVICEDESC) deviceDescVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_DEVICEDESC)
if err != nil { if err != nil {
return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err) return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
} }
deviceDesc, _ := deviceDescVal.(string) deviceDesc, _ := deviceDescVal.(string)
friendlyNameVal, err := deviceInfoSet.DeviceRegistryProperty(deviceInfoData, setupapi.SPDRP_FRIENDLYNAME) friendlyNameVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_FRIENDLYNAME)
if err != nil { if err != nil {
return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err) return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err)
} }
@@ -585,8 +603,8 @@ func (pool Pool) isMember(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupa
} }
// checkReboot checks device install parameters if a system reboot is required. // checkReboot checks device install parameters if a system reboot is required.
func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) bool { func checkReboot(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) bool {
devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData) devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
if err != nil { if err != nil {
return false return false
} }
@@ -595,14 +613,14 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
} }
// setQuietInstall sets device install parameters for a quiet installation // setQuietInstall sets device install parameters for a quiet installation
func setQuietInstall(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) error { func setQuietInstall(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) error {
devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData) devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
if err != nil { if err != nil {
return err return err
} }
devInstallParams.Flags |= setupapi.DI_QUIETINSTALL devInstallParams.Flags |= setupapi.DI_QUIETINSTALL
return deviceInfoSet.SetDeviceInstallParams(deviceInfoData, devInstallParams) return devInfo.SetDeviceInstallParams(devInfoData, devInstallParams)
} }
// deviceTypeName returns pool-specific device type name. // deviceTypeName returns pool-specific device type name.
@@ -721,18 +739,18 @@ func (wintun *Interface) tcpipInterfaceRegKeyName() (path string, err error) {
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil
} }
// deviceData returns TUN device info list handle and interface device info // devInfoData returns TUN device info list handle and interface device info
// data. The device info list handle must be closed after use. In case the // data. The device info list handle must be closed after use. In case the
// device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned. // device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned.
func (wintun *Interface) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData, error) { func (wintun *Interface) devInfoData() (setupapi.DevInfo, *setupapi.DevInfoData, error) {
// Create a list of network devices. // Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error()) return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())
} }
for index := 0; ; index++ { for index := 0; ; index++ {
deviceData, err := devInfoList.EnumDeviceInfo(index) devInfoData, err := devInfo.EnumDeviceInfo(index)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -742,22 +760,22 @@ func (wintun *Interface) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData,
// Get interface ID. // Get interface ID.
// TODO: Store some ID in the Wintun object such that this call isn't required. // TODO: Store some ID in the Wintun object such that this call isn't required.
wintun2, err := makeWintun(devInfoList, deviceData, wintun.pool) wintun2, err := makeWintun(devInfo, devInfoData, wintun.pool)
if err != nil { if err != nil {
continue continue
} }
if wintun.cfgInstanceID == wintun2.cfgInstanceID { if wintun.cfgInstanceID == wintun2.cfgInstanceID {
err = setQuietInstall(devInfoList, deviceData) err = setQuietInstall(devInfo, devInfoData)
if err != nil { if err != nil {
devInfoList.Close() devInfo.Close()
return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err) return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err)
} }
return devInfoList, deviceData, nil return devInfo, devInfoData, nil
} }
} }
devInfoList.Close() devInfo.Close()
return 0, nil, windows.ERROR_OBJECT_NOT_FOUND return 0, nil, windows.ERROR_OBJECT_NOT_FOUND
} }