31 Commits

Author SHA1 Message Date
Jason A. Donenfeld
ae88e2a2cd version: bump snapshot 2020-03-20 12:00:53 -06:00
Jason A. Donenfeld
4739708ca4 noise: unify zero checking of ecdh 2020-03-17 23:07:14 -06:00
Tobias Klauser
b33219c2cf global: use RTMGRP_* consts from x/sys/unix
Update the golang.org/x/sys/unix dependency and use the newly introduced
RTMGRP_* consts instead of using the corresponding RTNLGRP_* const to
create a mask.

Signed-off-by: Tobias Klauser <tklauser@distanz.ch>
2020-03-17 23:07:11 -06:00
Jason A. Donenfeld
9cbcff10dd send: account for zero mtu
Don't divide by zero.
2020-02-14 18:53:55 +01:00
Jason A. Donenfeld
6ed56ff2df device: fix private key removal logic 2020-02-04 22:02:53 +01:00
Jason A. Donenfeld
cb4bb63030 uapi: allow unsetting device private key with /dev/null 2020-02-04 22:02:53 +01:00
Jason A. Donenfeld
05b03c6750 version: bump snapshot 2020-01-21 16:27:19 +01:00
Jason A. Donenfeld
caebdfe9d0 tun: darwin: ignore ENOMEM errors
Coauthored-by: Andrej Mihajlov <and@mullvad.net>
2020-01-15 13:39:37 -05:00
Jason A. Donenfeld
4fa2ea6a2d tun: windows: serialize write calls 2020-01-07 11:40:45 -05:00
Jason A. Donenfeld
89dd065e53 README: update repo urls 2019-12-30 11:53:39 +01:00
Jason A. Donenfeld
ddfad453cf device: SendmsgN mutates the input sockaddr
So we take a new granular lock to prevent concurrent writes from
racing.

WARNING: DATA RACE
Write at 0x00c0011f2740 by goroutine 27:
  golang.org/x/sys/unix.(*SockaddrInet4).sockaddr()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:384
+0x114
  golang.org/x/sys/unix.SendmsgN()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:1304
+0x288
  golang.zx2c4.com/wireguard/device.send4()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:485
+0x11f
  golang.zx2c4.com/wireguard/device.(*nativeBind).Send()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:268
+0x1d6
  golang.zx2c4.com/wireguard/device.(*Peer).SendBuffer()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/peer.go:151
+0x285
  golang.zx2c4.com/wireguard/device.(*Peer).SendHandshakeInitiation()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:163
+0x692
  golang.zx2c4.com/wireguard/device.(*Device).RoutineReadFromTUN()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:318
+0x4b8

Previous write at 0x00c0011f2740 by goroutine 386:
  golang.org/x/sys/unix.(*SockaddrInet4).sockaddr()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:384
+0x114
  golang.org/x/sys/unix.SendmsgN()
      /go/pkg/mod/golang.org/x/sys@v0.0.0-20191105231009-c1f44814a5cd/unix/syscall_linux.go:1304
+0x288
  golang.zx2c4.com/wireguard/device.send4()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:485
+0x11f
  golang.zx2c4.com/wireguard/device.(*nativeBind).Send()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/conn_linux.go:268
+0x1d6
  golang.zx2c4.com/wireguard/device.(*Peer).SendBuffer()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/peer.go:151
+0x285
  golang.zx2c4.com/wireguard/device.(*Peer).SendHandshakeInitiation()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/send.go:163
+0x692
  golang.zx2c4.com/wireguard/device.expiredRetransmitHandshake()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/timers.go:110
+0x40c
  golang.zx2c4.com/wireguard/device.(*Peer).NewTimer.func1()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/timers.go:42
+0xd8

Goroutine 27 (running) created at:
  golang.zx2c4.com/wireguard/device.NewDevice()
      /go/pkg/mod/golang.zx2c4.com/wireguard@v0.0.20191012/device/device.go:322
+0x5e8
  main.main()
      /go/src/x/main.go:102 +0x58e

Goroutine 386 (finished) created at:
  time.goFunc()
      /usr/local/go/src/time/sleep.go:168 +0x51

Reported-by: Ben Burkert <ben@benburkert.com>
2019-11-28 11:11:13 +01:00
Jason A. Donenfeld
2b242f9393 wintun: manage ring memory manually
It's large and Go's garbage collector doesn't deal with it especially
well.
2019-11-22 13:13:55 +01:00
Jason A. Donenfeld
4cdf805b29 constants: recalculate rekey max based on a one minute flood
Discussed-with: Mathias Hall-Andersen <mathias@hall-andersen.dk>
2019-10-30 14:29:32 +01:00
Jonathan Tooker
f7d0edd2ec global: fix a few typos courtesy of codespell
Signed-off-by: Jonathan Tooker <jonathan.tooker@netprotect.com>
2019-10-22 11:51:25 +02:00
Jason A. Donenfeld
ffffbbcc8a device: allow blackholing sockets 2019-10-21 13:29:57 +02:00
Jason A. Donenfeld
47b02c618b device: remove dead error reporting code 2019-10-21 11:46:54 +02:00
Jason A. Donenfeld
fd23c66fcd namespaceapi: remove tasteless comment 2019-10-21 09:02:29 +02:00
Jason A. Donenfeld
ae492d1b35 device: recheck counters while holding write lock 2019-10-17 15:43:06 +02:00
Jason A. Donenfeld
95fbfccf60 wintun: normalize variable names for their types 2019-10-17 15:30:56 +02:00
Avery Pennarun
c85e4a410f wintun: quickly ignore non-Wintun devices
Some devices take ~2 seconds to enumerate on Windows if we try to get
their instance name.  The hardware id property, on the other hand,
is available right away.

Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
[zx2c4: inlined this to where it makes sense, reused setupapi const]
2019-10-17 15:19:20 +02:00
Avery Pennarun
1b6c8ddbe8 tun: match windows CreateTUN signature to the Linux variant
Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
[zx2c4: fix default value]
2019-10-17 15:19:20 +02:00
Avery Pennarun
0abb6b668c rwcancel: handle EINTR and EAGAIN in unixSelect()
On my Chromebook (Linux 4.19.44 in a VM) and on an AWS EC2
machine, select() was sometimes returning EINTR. This is
harmless and just means you should try again. So let's try
again.

This eliminates a problem where the tunnel fails to come up
correctly and the program needs to be restarted.

Signed-off-by: Avery Pennarun <apenwarr@gmail.com>
2019-10-17 15:19:17 +02:00
David Crawshaw
540d01e54a device: test packets between two fake devices
Signed-off-by: David Crawshaw <crawshaw@tailscale.io>
2019-10-16 11:38:28 +02:00
Jason A. Donenfeld
f2ea85e9f9 version: bump snapshot 2019-10-12 22:34:10 +02:00
Jason A. Donenfeld
222f0f8000 Makefile: remove v prefix 2019-10-08 16:48:18 +02:00
Jason A. Donenfeld
1f146a5e7a wintun: expose version 2019-10-08 09:58:58 +02:00
Jason A. Donenfeld
f2501aa6c8 uapi: allow preventing creation of new peers when updating
This enables race-free updates for wg-dynamic and similar tools.

Suggested-by: Thomas Gschwantner <tharre3@gmail.com>
2019-10-04 11:41:02 +02:00
Jason A. Donenfeld
cb8d01f58a mod: bump versions 2019-10-04 11:41:02 +02:00
Jason A. Donenfeld
01f8ef4e84 winpipe: use x/sys/windows instead of syscall 2019-09-16 23:39:16 -06:00
Jason A. Donenfeld
70f6c42556 wintun: use correct length for security attributes 2019-09-16 19:38:33 -06:00
Jason A. Donenfeld
bb0b2514c0 tun: windows: unify error message format 2019-09-08 13:52:44 -05:00
38 changed files with 692 additions and 482 deletions

View File

@@ -10,7 +10,7 @@ MAKEFLAGS += --no-print-directory
generate-version-and-build: generate-version-and-build:
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \ @export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
tag="$$(git describe --dirty 2>/dev/null)" && \ tag="$$(git describe --dirty 2>/dev/null)" && \
ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$$tag")" && \ ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$${tag#v}")" && \
[ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \ [ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \
echo "$$ver" > device/version.go && \ echo "$$ver" > device/version.go && \
git update-index --assume-unchanged device/version.go || true git update-index --assume-unchanged device/version.go || true

View File

@@ -18,7 +18,7 @@ To run wireguard-go without forking to the background, pass `-f` or `--foregroun
$ wireguard-go -f wg0 $ wireguard-go -f wg0
``` ```
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/WireGuard/about/src/tools/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands. When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
To run with more logging you may set the environment variable `LOG_LEVEL=debug`. To run with more logging you may set the environment variable `LOG_LEVEL=debug`.

View File

@@ -18,7 +18,7 @@ const (
sockoptIPV6_UNICAST_IF = 31 sockoptIPV6_UNICAST_IF = 31
) )
func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error { func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */ /* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
bytes := make([]byte, 4) bytes := make([]byte, 4)
binary.BigEndian.PutUint32(bytes, interfaceIndex) binary.BigEndian.PutUint32(bytes, interfaceIndex)
@@ -41,10 +41,11 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
if err != nil { if err != nil {
return err return err
} }
device.net.bind.(*nativeBind).blackhole4 = blackhole
return nil return nil
} }
func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error { func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn() sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
if err != nil { if err != nil {
return err return err
@@ -58,5 +59,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
if err != nil { if err != nil {
return err return err
} }
device.net.bind.(*nativeBind).blackhole6 = blackhole
return nil return nil
} }

View File

@@ -21,8 +21,10 @@ import (
*/ */
type nativeBind struct { type nativeBind struct {
ipv4 *net.UDPConn ipv4 *net.UDPConn
ipv6 *net.UDPConn ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
} }
type NativeEndpoint net.UDPAddr type NativeEndpoint net.UDPAddr
@@ -159,11 +161,17 @@ func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
if bind.ipv4 == nil { if bind.ipv4 == nil {
return syscall.EAFNOSUPPORT return syscall.EAFNOSUPPORT
} }
if bind.blackhole4 {
return nil
}
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) _, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
} else { } else {
if bind.ipv6 == nil { if bind.ipv6 == nil {
return syscall.EAFNOSUPPORT return syscall.EAFNOSUPPORT
} }
if bind.blackhole6 {
return nil
}
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend)) _, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
} }
return err return err

View File

@@ -7,7 +7,7 @@
* This implements userspace semantics of "sticky sockets", modeled after * This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port * WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code: * of the sticky-sockets.c example code:
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c * https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c
* *
* Currently there is no way to achieve this within the net package: * Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930 * See e.g. https://github.com/golang/go/issues/17930
@@ -43,6 +43,7 @@ type IPv6Source struct {
} }
type NativeEndpoint struct { type NativeEndpoint struct {
sync.Mutex
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
src [unsafe.Sizeof(IPv6Source{})]byte src [unsafe.Sizeof(IPv6Source{})]byte
isV6 bool isV6 bool
@@ -117,7 +118,7 @@ func createNetlinkRouteSocket() (int, error) {
} }
saddr := &unix.SockaddrNetlink{ saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK, Family: unix.AF_NETLINK,
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)), Groups: unix.RTMGRP_IPV4_ROUTE,
} }
err = unix.Bind(sock, saddr) err = unix.Bind(sock, saddr)
if err != nil { if err != nil {
@@ -145,7 +146,7 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
go bind.routineRouteListener(device) go bind.routineRouteListener(device)
// attempt ipv6 bind, update port if succesful // attempt ipv6 bind, update port if successful
bind.sock6, newPort, err = create6(port) bind.sock6, newPort, err = create6(port)
if err != nil { if err != nil {
@@ -157,7 +158,7 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
port = newPort port = newPort
} }
// attempt ipv4 bind, update port if succesful // attempt ipv4 bind, update port if successful
bind.sock4, newPort, err = create4(port) bind.sock4, newPort, err = create4(port)
if err != nil { if err != nil {
@@ -482,7 +483,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
}, },
} }
end.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.Unlock()
if err == nil { if err == nil {
return nil return nil
@@ -493,7 +496,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
if err == unix.EINVAL { if err == unix.EINVAL {
end.ClearSrc() end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{} cmsg.pktinfo = unix.Inet4Pktinfo{}
end.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0) _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
end.Unlock()
} }
return err return err
@@ -522,7 +527,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
cmsg.pktinfo.Ifindex = 0 cmsg.pktinfo.Ifindex = 0
} }
end.Lock()
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) _, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.Unlock()
if err == nil { if err == nil {
return nil return nil
@@ -533,7 +540,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
if err == unix.EINVAL { if err == unix.EINVAL {
end.ClearSrc() end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{} cmsg.pktinfo = unix.Inet6Pktinfo{}
end.Lock()
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0) _, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
end.Unlock()
} }
return err return err
@@ -541,7 +550,7 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) { func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header // construct message header
var cmsg struct { var cmsg struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr
@@ -573,7 +582,7 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) { func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header // construct message header
var cmsg struct { var cmsg struct {
cmsghdr unix.Cmsghdr cmsghdr unix.Cmsghdr

View File

@@ -12,7 +12,7 @@ import (
/* Specification constants */ /* Specification constants */
const ( const (
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1 RekeyAfterMessages = (1 << 60)
RejectAfterMessages = (1 << 64) - (1 << 4) - 1 RejectAfterMessages = (1 << 64) - (1 << 4) - 1
RekeyAfterTime = time.Second * 120 RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90 RekeyAttemptTime = time.Second * 90

View File

@@ -236,23 +236,11 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// do static-static DH pre-computations // do static-static DH pre-computations
rmKey := device.staticIdentity.privateKey.IsZero()
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap)) expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
for key, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
handshake := &peer.handshake handshake := &peer.handshake
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
if rmKey { expiredPeers = append(expiredPeers, peer)
handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
}
if isZero(handshake.precomputedStaticStatic[:]) {
unsafeRemovePeer(device, peer, key)
} else {
expiredPeers = append(expiredPeers, peer)
}
} }
for _, peer := range lockedPeers { for _, peer := range lockedPeers {

View File

@@ -5,54 +5,212 @@
package device package device
/* Create two device instances and simulate full WireGuard interaction
* without network dependencies
*/
import ( import (
"bufio"
"bytes" "bytes"
"encoding/binary"
"io"
"net"
"os"
"strings"
"testing" "testing"
"time"
"golang.zx2c4.com/wireguard/tun"
) )
func TestDevice(t *testing.T) { func TestTwoDevicePing(t *testing.T) {
// TODO(crawshaw): pick unused ports on localhost
// prepare tun devices for generating traffic cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58
listen_port=53511
tun1 := newDummyTUN("tun1") replace_peers=true
tun2 := newDummyTUN("tun2") public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725
protocol_version=1
_ = tun1 replace_allowed_ips=true
_ = tun2 allowed_ip=1.0.0.2/32
endpoint=127.0.0.1:53512`
// prepare endpoints tun1 := NewChannelTUN()
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
end1, err := CreateDummyEndpoint() dev1.Up()
if err != nil { defer dev1.Close()
t.Error("failed to create endpoint:", err.Error()) if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil {
}
end2, err := CreateDummyEndpoint()
if err != nil {
t.Error("failed to create endpoint:", err.Error())
}
_ = end1
_ = end2
// create binds
}
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err) t.Fatal(err)
} }
tun := newDummyTUN("dummy")
logger := NewLogger(LogLevelError, "") cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768
device := NewDevice(tun, logger) listen_port=53512
device.SetPrivateKey(sk) replace_peers=true
return device public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427
protocol_version=1
replace_allowed_ips=true
allowed_ip=1.0.0.1/32
endpoint=127.0.0.1:53511`
tun2 := NewChannelTUN()
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
dev2.Up()
defer dev2.Close()
if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil {
t.Fatal(err)
}
t.Run("ping 1.0.0.1", func(t *testing.T) {
msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
tun2.Outbound <- msg2to1
select {
case msgRecv := <-tun1.Inbound:
if !bytes.Equal(msg2to1, msgRecv) {
t.Error("ping did not transit correctly")
}
case <-time.After(300 * time.Millisecond):
t.Error("ping did not transit")
}
})
t.Run("ping 1.0.0.2", func(t *testing.T) {
msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
tun1.Outbound <- msg1to2
select {
case msgRecv := <-tun2.Inbound:
if !bytes.Equal(msg1to2, msgRecv) {
t.Error("return ping did not transit correctly")
}
case <-time.After(300 * time.Millisecond):
t.Error("return ping did not transit")
}
})
}
func ping(dst, src net.IP) []byte {
localPort := uint16(1337)
seq := uint16(0)
payload := make([]byte, 4)
binary.BigEndian.PutUint16(payload[0:], localPort)
binary.BigEndian.PutUint16(payload[2:], seq)
return genICMPv4(payload, dst, src)
}
// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
func checksum(buf []byte, initial uint16) uint16 {
v := uint32(initial)
for i := 0; i < len(buf)-1; i += 2 {
v += uint32(binary.BigEndian.Uint16(buf[i:]))
}
if len(buf)%2 == 1 {
v += uint32(buf[len(buf)-1]) << 8
}
for v > 0xffff {
v = (v >> 16) + (v & 0xffff)
}
return ^uint16(v)
}
func genICMPv4(payload []byte, dst, src net.IP) []byte {
const (
icmpv4ProtocolNumber = 1
icmpv4Echo = 8
icmpv4ChecksumOffset = 2
icmpv4Size = 8
ipv4Size = 20
ipv4TotalLenOffset = 2
ipv4ChecksumOffset = 10
ttl = 65
)
hdr := make([]byte, ipv4Size+icmpv4Size)
ip := hdr[0:ipv4Size]
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
// https://tools.ietf.org/html/rfc792
icmpv4[0] = icmpv4Echo // type
icmpv4[1] = 0 // code
chksum := ^checksum(icmpv4, checksum(payload, 0))
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
// https://tools.ietf.org/html/rfc760 section 3.1
length := uint16(len(hdr) + len(payload))
ip[0] = (4 << 4) | (ipv4Size / 4)
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
ip[8] = ttl
ip[9] = icmpv4ProtocolNumber
copy(ip[12:], src.To4())
copy(ip[16:], dst.To4())
chksum = ^checksum(ip[:], 0)
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
var v []byte
v = append(v, hdr...)
v = append(v, payload...)
return []byte(v)
}
// TODO(crawshaw): find a reusable home for this. package devicetest?
type ChannelTUN struct {
Inbound chan []byte // incoming packets, closed on TUN close
Outbound chan []byte // outbound packets, blocks forever on TUN close
closed chan struct{}
events chan tun.Event
tun chTun
}
func NewChannelTUN() *ChannelTUN {
c := &ChannelTUN{
Inbound: make(chan []byte),
Outbound: make(chan []byte),
closed: make(chan struct{}),
events: make(chan tun.Event, 1),
}
c.tun.c = c
c.events <- tun.EventUp
return c
}
func (c *ChannelTUN) TUN() tun.Device {
return &c.tun
}
type chTun struct {
c *ChannelTUN
}
func (t *chTun) File() *os.File { return nil }
func (t *chTun) Read(data []byte, offset int) (int, error) {
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case msg := <-t.c.Outbound:
return copy(data[offset:], msg), nil
}
}
// Write is called by the wireguard device to deliver a packet for routing.
func (t *chTun) Write(data []byte, offset int) (int, error) {
if offset == -1 {
close(t.c.closed)
close(t.c.events)
return 0, io.EOF
}
msg := make([]byte, len(data)-offset)
copy(msg, data[offset:])
select {
case <-t.c.closed:
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
case t.c.Inbound <- msg:
return len(data) - offset, nil
}
}
func (t *chTun) Flush() error { return nil }
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
func (t *chTun) Events() chan tun.Event { return t.c.events }
func (t *chTun) Close() error {
t.Write(nil, -1)
return nil
} }
func assertNil(t *testing.T, err error) { func assertNil(t *testing.T, err error) {
@@ -66,3 +224,15 @@ func assertEqual(t *testing.T, a, b []byte) {
t.Fatal(a, "!=", b) t.Fatal(a, "!=", b)
} }
} }
func randDevice(t *testing.T) *Device {
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
tun := newDummyTUN("dummy")
logger := NewLogger(LogLevelError, "")
device := NewDevice(tun, logger)
device.SetPrivateKey(sk)
return device
}

View File

@@ -39,13 +39,13 @@ const (
) )
const ( const (
MessageInitiationSize = 148 // size of handshake initation message MessageInitiationSize = 148 // size of handshake initiation message
MessageResponseSize = 92 // size of response message MessageResponseSize = 92 // size of response message
MessageCookieReplySize = 64 // size of cookie reply message MessageCookieReplySize = 64 // size of cookie reply message
MessageTransportHeaderSize = 16 // size of data preceeding content in transport message MessageTransportHeaderSize = 16 // size of data preceding content in transport message
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
MessageKeepaliveSize = MessageTransportSize // size of keepalive MessageKeepaliveSize = MessageTransportSize // size of keepalive
MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
) )
const ( const (
@@ -154,6 +154,7 @@ func init() {
} }
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) { func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
var errZeroECDHResult = errors.New("ECDH returned all zeros")
device.staticIdentity.RLock() device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock() defer device.staticIdentity.RUnlock()
@@ -162,12 +163,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
handshake.mutex.Lock() handshake.mutex.Lock()
defer handshake.mutex.Unlock() defer handshake.mutex.Unlock()
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errors.New("static shared secret is zero")
}
// create ephemeral key // create ephemeral key
var err error var err error
handshake.hash = InitialHash handshake.hash = InitialHash
handshake.chainKey = InitialChainKey handshake.chainKey = InitialChainKey
@@ -176,56 +172,53 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
return nil, err return nil, err
} }
// assign index
device.indexTable.Delete(handshake.localIndex)
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
handshake.mixHash(handshake.remoteStatic[:]) handshake.mixHash(handshake.remoteStatic[:])
msg := MessageInitiation{ msg := MessageInitiation{
Type: MessageInitiationType, Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(), Ephemeral: handshake.localEphemeral.publicKey(),
Sender: handshake.localIndex,
} }
handshake.mixKey(msg.Ephemeral[:]) handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:]) handshake.mixHash(msg.Ephemeral[:])
// encrypt static key // encrypt static key
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
func() { if isZero(ss[:]) {
var key [chacha20poly1305.KeySize]byte return nil, errZeroECDHResult
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic) }
KDF2( var key [chacha20poly1305.KeySize]byte
&handshake.chainKey, KDF2(
&key, &handshake.chainKey,
handshake.chainKey[:], &key,
ss[:], handshake.chainKey[:],
) ss[:],
aead, _ := chacha20poly1305.New(key[:]) )
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:]) aead, _ := chacha20poly1305.New(key[:])
}() aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
handshake.mixHash(msg.Static[:]) handshake.mixHash(msg.Static[:])
// encrypt timestamp // encrypt timestamp
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errZeroECDHResult
}
KDF2(
&handshake.chainKey,
&key,
handshake.chainKey[:],
handshake.precomputedStaticStatic[:],
)
timestamp := tai64n.Now() timestamp := tai64n.Now()
func() { aead, _ = chacha20poly1305.New(key[:])
var key [chacha20poly1305.KeySize]byte aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
KDF2(
&handshake.chainKey, // assign index
&key, device.indexTable.Delete(handshake.localIndex)
handshake.chainKey[:], msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
handshake.precomputedStaticStatic[:], if err != nil {
) return nil, err
aead, _ := chacha20poly1305.New(key[:]) }
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:]) handshake.localIndex = msg.Sender
}()
handshake.mixHash(msg.Timestamp[:]) handshake.mixHash(msg.Timestamp[:])
handshake.state = HandshakeInitiationCreated handshake.state = HandshakeInitiationCreated
@@ -250,16 +243,16 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key // decrypt static key
var err error var err error
var peerPK NoisePublicKey var peerPK NoisePublicKey
func() { var key [chacha20poly1305.KeySize]byte
var key [chacha20poly1305.KeySize]byte ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) if isZero(ss[:]) {
KDF2(&chainKey, &key, chainKey[:], ss[:]) return nil
aead, _ := chacha20poly1305.New(key[:]) }
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:]) KDF2(&chainKey, &key, chainKey[:], ss[:])
}() aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
if err != nil { if err != nil {
return nil return nil
} }
@@ -273,23 +266,24 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
} }
handshake := &peer.handshake handshake := &peer.handshake
if isZero(handshake.precomputedStaticStatic[:]) {
return nil
}
// verify identity // verify identity
var timestamp tai64n.Timestamp var timestamp tai64n.Timestamp
var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock() handshake.mutex.RLock()
if isZero(handshake.precomputedStaticStatic[:]) {
handshake.mutex.RUnlock()
return nil
}
KDF2( KDF2(
&chainKey, &chainKey,
&key, &key,
chainKey[:], chainKey[:],
handshake.precomputedStaticStatic[:], handshake.precomputedStaticStatic[:],
) )
aead, _ := chacha20poly1305.New(key[:]) aead, _ = chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:]) _, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
if err != nil { if err != nil {
handshake.mutex.RUnlock() handshake.mutex.RUnlock()
@@ -315,8 +309,13 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
handshake.chainKey = chainKey handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral handshake.remoteEphemeral = msg.Ephemeral
handshake.lastTimestamp = timestamp if timestamp.After(handshake.lastTimestamp) {
handshake.lastInitiationConsumption = time.Now() handshake.lastTimestamp = timestamp
}
now := time.Now()
if now.After(handshake.lastInitiationConsumption) {
handshake.lastInitiationConsumption = now
}
handshake.state = HandshakeInitiationConsumed handshake.state = HandshakeInitiationConsumed
handshake.mutex.Unlock() handshake.mutex.Unlock()

View File

@@ -52,6 +52,15 @@ func (key *NoisePrivateKey) FromHex(src string) (err error) {
return return
} }
func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
err = loadExactHex(key[:], src)
if key.IsZero() {
return
}
key.clamp()
return
}
func (key NoisePrivateKey) ToHex() string { func (key NoisePrivateKey) ToHex() string {
return hex.EncodeToString(key[:]) return hex.EncodeToString(key[:])
} }

View File

@@ -108,7 +108,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
handshake := &peer.handshake handshake := &peer.handshake
handshake.mutex.Lock() handshake.mutex.Lock()
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk) handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
ssIsZero := isZero(handshake.precomputedStaticStatic[:])
handshake.remoteStatic = pk handshake.remoteStatic = pk
handshake.mutex.Unlock() handshake.mutex.Unlock()
@@ -116,13 +115,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
peer.endpoint = nil peer.endpoint = nil
// conditionally add // add
if !ssIsZero { device.peers.keyMap[pk] = peer
device.peers.keyMap[pk] = peer
} else {
return nil, nil
}
// start peer // start peer

View File

@@ -220,10 +220,7 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
writer := bytes.NewBuffer(buff[:0]) writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, reply) binary.Write(writer, binary.LittleEndian, reply)
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint) device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
if err != nil { return nil
device.log.Error.Println("Failed to send cookie reply:", err)
}
return err
} }
func (peer *Peer) keepKeyFreshSending() { func (peer *Peer) keepKeyFreshSending() {
@@ -518,10 +515,18 @@ func (device *Device) RoutineEncryption() {
// pad content to multiple of 16 // pad content to multiple of 16
mtu := int(atomic.LoadInt32(&device.tun.mtu)) mtu := int(atomic.LoadInt32(&device.tun.mtu))
lastUnit := len(elem.packet) % mtu var paddedSize int
paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1) if mtu == 0 {
if paddedSize > mtu { paddedSize = (len(elem.packet) + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
paddedSize = mtu } else {
lastUnit := len(elem.packet)
if lastUnit > mtu {
lastUnit %= mtu
}
paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
if paddedSize > mtu {
paddedSize = mtu
}
} }
for i := len(elem.packet); i < paddedSize; i++ { for i := len(elem.packet); i < paddedSize; i++ {
elem.packet = append(elem.packet, 0) elem.packet = append(elem.packet, 0)

View File

@@ -113,6 +113,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
var peer *Peer var peer *Peer
dummy := false dummy := false
createdNewPeer := false
deviceConfig := true deviceConfig := true
for scanner.Scan() { for scanner.Scan() {
@@ -137,7 +138,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
switch key { switch key {
case "private_key": case "private_key":
var sk NoisePrivateKey var sk NoisePrivateKey
err := sk.FromHex(value) err := sk.FromMaybeZeroHex(value)
if err != nil { if err != nil {
logError.Println("Failed to set private_key:", err) logError.Println("Failed to set private_key:", err)
return &IPCError{ipc.IpcErrorInvalid} return &IPCError{ipc.IpcErrorInvalid}
@@ -237,7 +238,8 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
peer = device.LookupPeer(publicKey) peer = device.LookupPeer(publicKey)
} }
if peer == nil { createdNewPeer = peer == nil
if createdNewPeer {
peer, err = device.NewPeer(publicKey) peer, err = device.NewPeer(publicKey)
if err != nil { if err != nil {
logError.Println("Failed to create new peer:", err) logError.Println("Failed to create new peer:", err)
@@ -251,6 +253,20 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
} }
} }
case "update_only":
// allow disabling of creation
if value != "true" {
logError.Println("Failed to set update only, invalid value:", value)
return &IPCError{ipc.IpcErrorInvalid}
}
if createdNewPeer && !dummy {
device.RemovePeer(peer.handshake.remoteStatic)
peer = &Peer{}
dummy = true
}
case "remove": case "remove":
// remove currently selected peer from device // remove currently selected peer from device

View File

@@ -1,3 +1,3 @@
package device package device
const WireGuardGoVersion = "0.0.20190908" const WireGuardGoVersion = "0.0.20200320"

6
go.mod
View File

@@ -3,8 +3,8 @@ module golang.zx2c4.com/wireguard
go 1.12 go 1.12
require ( require (
golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 golang.org/x/net v0.0.0-20191003171128-d98b1b443823
golang.org/x/sys v0.0.0-20190830023255-19e00faab6ad golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527
golang.org/x/text v0.3.2 golang.org/x/text v0.3.2
) )

12
go.sum
View File

@@ -1,13 +1,13 @@
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 h1:Gv7RPwsi3eZ2Fgewe3CBsuOebPwO27PoXzRpJPsvSSM= golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4=
golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 h1:k7pJ2yAPLPgbskkFdhRCsA77k2fySZ1zf2zCjvQCiIM= golang.org/x/net v0.0.0-20191003171128-d98b1b443823 h1:Ypyv6BNJh07T1pUSrehkLemqPKXhus2MkfktJ91kRh4=
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190830023255-19e00faab6ad h1:cCejgArrk10gX6kFqjWeLwXD7aVMqWoRpyUCaaJSggc= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527 h1:uYVVQ9WP/Ds2ROhcaGPeIdVq0RIXVLwsHlnvJ+cT1So=
golang.org/x/sys v0.0.0-20190830023255-19e00faab6ad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200302150141-5c8b2ff67527/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=

View File

@@ -8,6 +8,8 @@ package ipc
import ( import (
"net" "net"
"golang.org/x/sys/windows"
"golang.zx2c4.com/wireguard/ipc/winpipe" "golang.zx2c4.com/wireguard/ipc/winpipe"
) )
@@ -47,8 +49,16 @@ func (l *UAPIListener) Addr() net.Addr {
return l.listener.Addr() return l.listener.Addr()
} }
/* SDDL_DEVOBJ_SYS_ALL from the WDK */ var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR
var UAPISecurityDescriptor = "O:SYD:P(A;;GA;;;SY)"
func init() {
var err error
/* SDDL_DEVOBJ_SYS_ALL from the WDK */
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
if err != nil {
panic(err)
}
}
func UAPIListen(name string) (net.Listener, error) { func UAPIListen(name string) (net.Listener, error) {
config := winpipe.PipeConfig{ config := winpipe.PipeConfig{

View File

@@ -13,15 +13,16 @@ import (
"runtime" "runtime"
"sync" "sync"
"sync/atomic" "sync/atomic"
"syscall"
"time" "time"
"golang.org/x/sys/windows"
) )
//sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx //sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx
//sys createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) = CreateIoCompletionPort //sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort
//sys getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus //sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
//sys setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes //sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
//sys wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult //sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
type atomicBool int32 type atomicBool int32
@@ -55,7 +56,7 @@ func (e *timeoutError) Temporary() bool { return true }
type timeoutChan chan struct{} type timeoutChan chan struct{}
var ioInitOnce sync.Once var ioInitOnce sync.Once
var ioCompletionPort syscall.Handle var ioCompletionPort windows.Handle
// ioResult contains the result of an asynchronous IO operation // ioResult contains the result of an asynchronous IO operation
type ioResult struct { type ioResult struct {
@@ -65,12 +66,12 @@ type ioResult struct {
// ioOperation represents an outstanding asynchronous Win32 IO // ioOperation represents an outstanding asynchronous Win32 IO
type ioOperation struct { type ioOperation struct {
o syscall.Overlapped o windows.Overlapped
ch chan ioResult ch chan ioResult
} }
func initIo() { func initIo() {
h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff) h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff)
if err != nil { if err != nil {
panic(err) panic(err)
} }
@@ -81,7 +82,7 @@ func initIo() {
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall. // win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
// It takes ownership of this handle and will close it if it is garbage collected. // It takes ownership of this handle and will close it if it is garbage collected.
type win32File struct { type win32File struct {
handle syscall.Handle handle windows.Handle
wg sync.WaitGroup wg sync.WaitGroup
wgLock sync.RWMutex wgLock sync.RWMutex
closing atomicBool closing atomicBool
@@ -99,7 +100,7 @@ type deadlineHandler struct {
} }
// makeWin32File makes a new win32File from an existing file handle // makeWin32File makes a new win32File from an existing file handle
func makeWin32File(h syscall.Handle) (*win32File, error) { func makeWin32File(h windows.Handle) (*win32File, error) {
f := &win32File{handle: h} f := &win32File{handle: h}
ioInitOnce.Do(initIo) ioInitOnce.Do(initIo)
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff) _, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
@@ -115,7 +116,7 @@ func makeWin32File(h syscall.Handle) (*win32File, error) {
return f, nil return f, nil
} }
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) { func MakeOpenFile(h windows.Handle) (io.ReadWriteCloser, error) {
return makeWin32File(h) return makeWin32File(h)
} }
@@ -129,7 +130,7 @@ func (f *win32File) closeHandle() {
cancelIoEx(f.handle, nil) cancelIoEx(f.handle, nil)
f.wg.Wait() f.wg.Wait()
// at this point, no new IO can start // at this point, no new IO can start
syscall.Close(f.handle) windows.Close(f.handle)
f.handle = 0 f.handle = 0
} else { } else {
f.wgLock.Unlock() f.wgLock.Unlock()
@@ -158,12 +159,12 @@ func (f *win32File) prepareIo() (*ioOperation, error) {
} }
// ioCompletionProcessor processes completed async IOs forever // ioCompletionProcessor processes completed async IOs forever
func ioCompletionProcessor(h syscall.Handle) { func ioCompletionProcessor(h windows.Handle) {
for { for {
var bytes uint32 var bytes uint32
var key uintptr var key uintptr
var op *ioOperation var op *ioOperation
err := getQueuedCompletionStatus(h, &bytes, &key, &op, syscall.INFINITE) err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE)
if op == nil { if op == nil {
panic(err) panic(err)
} }
@@ -174,7 +175,7 @@ func ioCompletionProcessor(h syscall.Handle) {
// asyncIo processes the return value from ReadFile or WriteFile, blocking until // asyncIo processes the return value from ReadFile or WriteFile, blocking until
// the operation has actually completed. // the operation has actually completed.
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) { func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
if err != syscall.ERROR_IO_PENDING { if err != windows.ERROR_IO_PENDING {
return int(bytes), err return int(bytes), err
} }
@@ -193,7 +194,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
select { select {
case r = <-c.ch: case r = <-c.ch:
err = r.err err = r.err
if err == syscall.ERROR_OPERATION_ABORTED { if err == windows.ERROR_OPERATION_ABORTED {
if f.closing.isSet() { if f.closing.isSet() {
err = ErrFileClosed err = ErrFileClosed
} }
@@ -206,7 +207,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
cancelIoEx(f.handle, &c.o) cancelIoEx(f.handle, &c.o)
r = <-c.ch r = <-c.ch
err = r.err err = r.err
if err == syscall.ERROR_OPERATION_ABORTED { if err == windows.ERROR_OPERATION_ABORTED {
err = ErrTimeout err = ErrTimeout
} }
} }
@@ -231,14 +232,14 @@ func (f *win32File) Read(b []byte) (int, error) {
} }
var bytes uint32 var bytes uint32
err = syscall.ReadFile(f.handle, b, &bytes, &c.o) err = windows.ReadFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.readDeadline, bytes, err) n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
runtime.KeepAlive(b) runtime.KeepAlive(b)
// Handle EOF conditions. // Handle EOF conditions.
if err == nil && n == 0 && len(b) != 0 { if err == nil && n == 0 && len(b) != 0 {
return 0, io.EOF return 0, io.EOF
} else if err == syscall.ERROR_BROKEN_PIPE { } else if err == windows.ERROR_BROKEN_PIPE {
return 0, io.EOF return 0, io.EOF
} else { } else {
return n, err return n, err
@@ -258,7 +259,7 @@ func (f *win32File) Write(b []byte) (int, error) {
} }
var bytes uint32 var bytes uint32
err = syscall.WriteFile(f.handle, b, &bytes, &c.o) err = windows.WriteFile(f.handle, b, &bytes, &c.o)
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err) n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
runtime.KeepAlive(b) runtime.KeepAlive(b)
return n, err return n, err
@@ -273,7 +274,7 @@ func (f *win32File) SetWriteDeadline(deadline time.Time) error {
} }
func (f *win32File) Flush() error { func (f *win32File) Flush() error {
return syscall.FlushFileBuffers(f.handle) return windows.FlushFileBuffers(f.handle)
} }
func (f *win32File) Fd() uintptr { func (f *win32File) Fd() uintptr {

View File

@@ -6,4 +6,4 @@
package winpipe package winpipe
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go pipe.go sd.go file.go //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go pipe.go file.go

View File

@@ -16,18 +16,19 @@ import (
"net" "net"
"os" "os"
"runtime" "runtime"
"syscall"
"time" "time"
"unsafe" "unsafe"
"golang.org/x/sys/windows"
) )
//sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe //sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateNamedPipeW //sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW
//sys createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateFileW //sys createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW
//sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo //sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
//sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW //sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc //sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile //sys ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb //sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U //sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl //sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
@@ -41,7 +42,7 @@ type objectAttributes struct {
RootDirectory uintptr RootDirectory uintptr
ObjectName *unicodeString ObjectName *unicodeString
Attributes uintptr Attributes uintptr
SecurityDescriptor *securityDescriptor SecurityDescriptor *windows.SECURITY_DESCRIPTOR
SecurityQoS uintptr SecurityQoS uintptr
} }
@@ -51,16 +52,6 @@ type unicodeString struct {
Buffer uintptr Buffer uintptr
} }
type securityDescriptor struct {
Revision byte
Sbz1 byte
Control uint16
Owner uintptr
Group uintptr
Sacl uintptr
Dacl uintptr
}
type ntstatus int32 type ntstatus int32
func (status ntstatus) Err() error { func (status ntstatus) Err() error {
@@ -71,11 +62,6 @@ func (status ntstatus) Err() error {
} }
const ( const (
cERROR_PIPE_BUSY = syscall.Errno(231)
cERROR_NO_DATA = syscall.Errno(232)
cERROR_PIPE_CONNECTED = syscall.Errno(535)
cERROR_SEM_TIMEOUT = syscall.Errno(121)
cSECURITY_SQOS_PRESENT = 0x100000 cSECURITY_SQOS_PRESENT = 0x100000
cSECURITY_ANONYMOUS = 0 cSECURITY_ANONYMOUS = 0
@@ -88,8 +74,6 @@ const (
cFILE_PIPE_MESSAGE_TYPE = 1 cFILE_PIPE_MESSAGE_TYPE = 1
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2 cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
cSE_DACL_PRESENT = 4
) )
var ( var (
@@ -170,7 +154,7 @@ func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
// zero-byte message, ensure that all future Read() calls // zero-byte message, ensure that all future Read() calls
// also return EOF. // also return EOF.
f.readEOF = true f.readEOF = true
} else if err == syscall.ERROR_MORE_DATA { } else if err == windows.ERROR_MORE_DATA {
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode // ERROR_MORE_DATA indicates that the pipe's read mode is message mode
// and the message still has more bytes. Treat this as a success, since // and the message still has more bytes. Treat this as a success, since
// this package presents all named pipes as byte streams. // this package presents all named pipes as byte streams.
@@ -188,17 +172,17 @@ func (s pipeAddress) String() string {
} }
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout. // tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) { func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return syscall.Handle(0), ctx.Err() return windows.Handle(0), ctx.Err()
default: default:
h, err := createFile(*path, syscall.GENERIC_READ|syscall.GENERIC_WRITE, 0, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0) h, err := createFile(*path, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
if err == nil { if err == nil {
return h, nil return h, nil
} }
if err != cERROR_PIPE_BUSY { if err != windows.ERROR_PIPE_BUSY {
return h, &os.PathError{Err: err, Op: "open", Path: *path} return h, &os.PathError{Err: err, Op: "open", Path: *path}
} }
// Wait 10 msec and try again. This is a rather simplistic // Wait 10 msec and try again. This is a rather simplistic
@@ -211,7 +195,7 @@ func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
// DialPipe connects to a named pipe by path, timing out if the connection // DialPipe connects to a named pipe by path, timing out if the connection
// takes longer than the specified duration. If timeout is nil, then we use // takes longer than the specified duration. If timeout is nil, then we use
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.) // a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
func DialPipe(path string, timeout *time.Duration, expectedOwner *syscall.SID) (net.Conn, error) { func DialPipe(path string, timeout *time.Duration, expectedOwner *windows.SID) (net.Conn, error) {
var absTimeout time.Time var absTimeout time.Time
if timeout != nil { if timeout != nil {
absTimeout = time.Now().Add(*timeout) absTimeout = time.Now().Add(*timeout)
@@ -228,39 +212,41 @@ func DialPipe(path string, timeout *time.Duration, expectedOwner *syscall.SID) (
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx` // DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
// cancellation or timeout. // cancellation or timeout.
func DialPipeContext(ctx context.Context, path string, expectedOwner *syscall.SID) (net.Conn, error) { func DialPipeContext(ctx context.Context, path string, expectedOwner *windows.SID) (net.Conn, error) {
var err error var err error
var h syscall.Handle var h windows.Handle
h, err = tryDialPipe(ctx, &path) h, err = tryDialPipe(ctx, &path)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if expectedOwner != nil { if expectedOwner != nil {
var realOwner *syscall.SID sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
var realSd uintptr
err = getSecurityInfo(h, SE_FILE_OBJECT, OWNER_SECURITY_INFORMATION, &realOwner, nil, nil, nil, &realSd)
if err != nil { if err != nil {
syscall.Close(h) windows.Close(h)
return nil, err return nil, err
} }
defer localFree(realSd) realOwner, _, err := sd.Owner()
if !equalSid(realOwner, expectedOwner) { if err != nil {
syscall.Close(h) windows.Close(h)
return nil, syscall.ERROR_ACCESS_DENIED return nil, err
}
if !realOwner.Equals(expectedOwner) {
windows.Close(h)
return nil, windows.ERROR_ACCESS_DENIED
} }
} }
var flags uint32 var flags uint32
err = getNamedPipeInfo(h, &flags, nil, nil, nil) err = getNamedPipeInfo(h, &flags, nil, nil, nil)
if err != nil { if err != nil {
syscall.Close(h) windows.Close(h)
return nil, err return nil, err
} }
f, err := makeWin32File(h) f, err := makeWin32File(h)
if err != nil { if err != nil {
syscall.Close(h) windows.Close(h)
return nil, err return nil, err
} }
@@ -280,7 +266,7 @@ type acceptResponse struct {
} }
type win32PipeListener struct { type win32PipeListener struct {
firstHandle syscall.Handle firstHandle windows.Handle
path string path string
config PipeConfig config PipeConfig
acceptCh chan (chan acceptResponse) acceptCh chan (chan acceptResponse)
@@ -288,8 +274,8 @@ type win32PipeListener struct {
doneCh chan int doneCh chan int
} }
func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (syscall.Handle, error) { func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *PipeConfig, first bool) (windows.Handle, error) {
path16, err := syscall.UTF16FromString(path) path16, err := windows.UTF16FromString(path)
if err != nil { if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err} return 0, &os.PathError{Op: "open", Path: path, Err: err}
} }
@@ -301,31 +287,32 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil { if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err} return 0, &os.PathError{Op: "open", Path: path, Err: err}
} }
defer localFree(ntPath.Buffer) defer windows.LocalFree(windows.Handle(ntPath.Buffer))
oa.ObjectName = &ntPath oa.ObjectName = &ntPath
// The security descriptor is only needed for the first pipe. // The security descriptor is only needed for the first pipe.
if first { if first {
if sd != nil { if sd != nil {
len := uint32(len(sd)) oa.SecurityDescriptor = sd
sdb := localAlloc(0, len)
defer localFree(sdb)
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
} else { } else {
// Construct the default named pipe security descriptor. // Construct the default named pipe security descriptor.
var dacl uintptr var dacl uintptr
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil { if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
return 0, fmt.Errorf("getting default named pipe ACL: %s", err) return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
} }
defer localFree(dacl) defer windows.LocalFree(windows.Handle(dacl))
sd, err := windows.NewSecurityDescriptor()
sdb := &securityDescriptor{ if err != nil {
Revision: 1, return 0, fmt.Errorf("creating new security descriptor: %s", err)
Control: cSE_DACL_PRESENT,
Dacl: dacl,
} }
oa.SecurityDescriptor = sdb if err = sd.SetDACL((*windows.ACL)(unsafe.Pointer(dacl)), true, false); err != nil {
return 0, fmt.Errorf("assigning dacl: %s", err)
}
sd, err = sd.ToSelfRelative()
if err != nil {
return 0, fmt.Errorf("converting to self-relative: %s", err)
}
oa.SecurityDescriptor = sd
} }
} }
@@ -335,22 +322,22 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
} }
disposition := uint32(cFILE_OPEN) disposition := uint32(cFILE_OPEN)
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE) access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
if first { if first {
disposition = cFILE_CREATE disposition = cFILE_CREATE
// By not asking for read or write access, the named pipe file system // By not asking for read or write access, the named pipe file system
// will put this pipe into an initially disconnected state, blocking // will put this pipe into an initially disconnected state, blocking
// client connections until the next call with first == false. // client connections until the next call with first == false.
access = syscall.SYNCHRONIZE access = windows.SYNCHRONIZE
} }
timeout := int64(-50 * 10000) // 50ms timeout := int64(-50 * 10000) // 50ms
var ( var (
h syscall.Handle h windows.Handle
iosb ioStatusBlock iosb ioStatusBlock
) )
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err() err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
if err != nil { if err != nil {
return 0, &os.PathError{Op: "open", Path: path, Err: err} return 0, &os.PathError{Op: "open", Path: path, Err: err}
} }
@@ -366,7 +353,7 @@ func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
} }
f, err := makeWin32File(h) f, err := makeWin32File(h)
if err != nil { if err != nil {
syscall.Close(h) windows.Close(h)
return nil, err return nil, err
} }
return f, nil return f, nil
@@ -417,7 +404,7 @@ func (l *win32PipeListener) listenerRoutine() {
p, err = l.makeConnectedServerPipe() p, err = l.makeConnectedServerPipe()
// If the connection was immediately closed by the client, try // If the connection was immediately closed by the client, try
// again. // again.
if err != cERROR_NO_DATA { if err != windows.ERROR_NO_DATA {
break break
} }
} }
@@ -425,7 +412,7 @@ func (l *win32PipeListener) listenerRoutine() {
closed = err == ErrPipeListenerClosed closed = err == ErrPipeListenerClosed
} }
} }
syscall.Close(l.firstHandle) windows.Close(l.firstHandle)
l.firstHandle = 0 l.firstHandle = 0
// Notify Close() and Accept() callers that the handle has been closed. // Notify Close() and Accept() callers that the handle has been closed.
close(l.doneCh) close(l.doneCh)
@@ -433,8 +420,8 @@ func (l *win32PipeListener) listenerRoutine() {
// PipeConfig contain configuration for the pipe listener. // PipeConfig contain configuration for the pipe listener.
type PipeConfig struct { type PipeConfig struct {
// SecurityDescriptor contains a Windows security descriptor in SDDL format. // SecurityDescriptor contains a Windows security descriptor.
SecurityDescriptor string SecurityDescriptor *windows.SECURITY_DESCRIPTOR
// MessageMode determines whether the pipe is in byte or message mode. In either // MessageMode determines whether the pipe is in byte or message mode. In either
// case the pipe is read in byte mode by default. The only practical difference in // case the pipe is read in byte mode by default. The only practical difference in
@@ -454,20 +441,10 @@ type PipeConfig struct {
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe. // ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
// The pipe must not already exist. // The pipe must not already exist.
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) { func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
var (
sd []byte
err error
)
if c == nil { if c == nil {
c = &PipeConfig{} c = &PipeConfig{}
} }
if c.SecurityDescriptor != "" { h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
sd, err = SddlToSecurityDescriptor(c.SecurityDescriptor)
if err != nil {
return nil, err
}
}
h, err := makeServerPipeHandle(path, sd, c, true)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -492,7 +469,7 @@ func connectPipe(p *win32File) error {
err = connectNamedPipe(p.handle, &c.o) err = connectNamedPipe(p.handle, &c.o)
_, err = p.asyncIo(c, nil, 0, err) _, err = p.asyncIo(c, nil, 0, err)
if err != nil && err != cERROR_PIPE_CONNECTED { if err != nil && err != windows.ERROR_PIPE_CONNECTED {
return err return err
} }
return nil return nil

View File

@@ -1,36 +0,0 @@
// +build windows
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2005 Microsoft
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package winpipe
import (
"unsafe"
)
//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
//sys localFree(mem uintptr) = LocalFree
//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
//sys getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) = advapi32.GetSecurityInfo
//sys equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) = advapi32.EqualSid
const (
SE_FILE_OBJECT = 1
OWNER_SECURITY_INFORMATION = 1
)
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
var sdBuffer uintptr
err := convertStringSecurityDescriptorToSecurityDescriptor(sddl, 1, &sdBuffer, nil)
if err != nil {
return nil, err
}
defer localFree(sdBuffer)
sd := make([]byte, getSecurityDescriptorLength(sdBuffer))
copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
return sd, nil
}

View File

@@ -39,32 +39,26 @@ func errnoErr(e syscall.Errno) error {
var ( var (
modkernel32 = windows.NewLazySystemDLL("kernel32.dll") modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
modntdll = windows.NewLazySystemDLL("ntdll.dll") modntdll = windows.NewLazySystemDLL("ntdll.dll")
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll") modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe") procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW") procCreateNamedPipeW = modkernel32.NewProc("CreateNamedPipeW")
procCreateFileW = modkernel32.NewProc("CreateFileW") procCreateFileW = modkernel32.NewProc("CreateFileW")
procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo") procGetNamedPipeInfo = modkernel32.NewProc("GetNamedPipeInfo")
procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW") procGetNamedPipeHandleStateW = modkernel32.NewProc("GetNamedPipeHandleStateW")
procLocalAlloc = modkernel32.NewProc("LocalAlloc") procLocalAlloc = modkernel32.NewProc("LocalAlloc")
procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile") procNtCreateNamedPipeFile = modntdll.NewProc("NtCreateNamedPipeFile")
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb") procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U") procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl") procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW") procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procLocalFree = modkernel32.NewProc("LocalFree") procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength") procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
procGetSecurityInfo = modadvapi32.NewProc("GetSecurityInfo") procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
procEqualSid = modadvapi32.NewProc("EqualSid") procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
procSetFileCompletionNotificationModes = modkernel32.NewProc("SetFileCompletionNotificationModes")
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
) )
func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) { func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0) r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 { if r1 == 0 {
if e1 != 0 { if e1 != 0 {
@@ -76,7 +70,7 @@ func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
return return
} }
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) { func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
var _p0 *uint16 var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name) _p0, err = syscall.UTF16PtrFromString(name)
if err != nil { if err != nil {
@@ -85,10 +79,10 @@ func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances ui
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa) return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
} }
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) { func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0) r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
handle = syscall.Handle(r0) handle = windows.Handle(r0)
if handle == syscall.InvalidHandle { if handle == windows.InvalidHandle {
if e1 != 0 { if e1 != 0 {
err = errnoErr(e1) err = errnoErr(e1)
} else { } else {
@@ -98,7 +92,7 @@ func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances
return return
} }
func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) { func createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
var _p0 *uint16 var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(name) _p0, err = syscall.UTF16PtrFromString(name)
if err != nil { if err != nil {
@@ -107,10 +101,10 @@ func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAtt
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile) return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
} }
func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) { func _createFile(name *uint16, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0) r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
handle = syscall.Handle(r0) handle = windows.Handle(r0)
if handle == syscall.InvalidHandle { if handle == windows.InvalidHandle {
if e1 != 0 { if e1 != 0 {
err = errnoErr(e1) err = errnoErr(e1)
} else { } else {
@@ -120,7 +114,7 @@ func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityA
return return
} }
func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) { func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0) r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
if r1 == 0 { if r1 == 0 {
if e1 != 0 { if e1 != 0 {
@@ -132,7 +126,7 @@ func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSiz
return return
} }
func getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) { func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0) r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
if r1 == 0 { if r1 == 0 {
if e1 != 0 { if e1 != 0 {
@@ -150,7 +144,7 @@ func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
return return
} }
func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) { func ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0) r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
status = ntstatus(r0) status = ntstatus(r0)
return return
@@ -176,53 +170,7 @@ func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) {
return return
} }
func convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) { func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) {
var _p0 *uint16
_p0, err = syscall.UTF16PtrFromString(str)
if err != nil {
return
}
return _convertStringSecurityDescriptorToSecurityDescriptor(_p0, revision, sd, size)
}
func _convertStringSecurityDescriptorToSecurityDescriptor(str *uint16, revision uint32, sd *uintptr, size *uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procConvertStringSecurityDescriptorToSecurityDescriptorW.Addr(), 4, uintptr(unsafe.Pointer(str)), uintptr(revision), uintptr(unsafe.Pointer(sd)), uintptr(unsafe.Pointer(size)), 0, 0)
if r1 == 0 {
if e1 != 0 {
err = errnoErr(e1)
} else {
err = syscall.EINVAL
}
}
return
}
func localFree(mem uintptr) {
syscall.Syscall(procLocalFree.Addr(), 1, uintptr(mem), 0, 0)
return
}
func getSecurityDescriptorLength(sd uintptr) (len uint32) {
r0, _, _ := syscall.Syscall(procGetSecurityDescriptorLength.Addr(), 1, uintptr(sd), 0, 0)
len = uint32(r0)
return
}
func getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) {
r0, _, _ := syscall.Syscall9(procGetSecurityInfo.Addr(), 8, uintptr(handle), uintptr(objectType), uintptr(securityInformation), uintptr(unsafe.Pointer(owner)), uintptr(unsafe.Pointer(group)), uintptr(unsafe.Pointer(dacl)), uintptr(unsafe.Pointer(sacl)), uintptr(unsafe.Pointer(sd)), 0)
if r0 != 0 {
ret = syscall.Errno(r0)
}
return
}
func equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) {
r0, _, _ := syscall.Syscall(procEqualSid.Addr(), 2, uintptr(unsafe.Pointer(sid1)), uintptr(unsafe.Pointer(sid2)), 0)
isEqual = r0 != 0
return
}
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0) r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
if r1 == 0 { if r1 == 0 {
if e1 != 0 { if e1 != 0 {
@@ -234,9 +182,9 @@ func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
return return
} }
func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) { func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) {
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0) r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
newport = syscall.Handle(r0) newport = windows.Handle(r0)
if newport == 0 { if newport == 0 {
if e1 != 0 { if e1 != 0 {
err = errnoErr(e1) err = errnoErr(e1)
@@ -247,7 +195,7 @@ func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintpt
return return
} }
func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) { func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0) r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
if r1 == 0 { if r1 == 0 {
if e1 != 0 { if e1 != 0 {
@@ -259,7 +207,7 @@ func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr,
return return
} }
func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) { func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0) r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
if r1 == 0 { if r1 == 0 {
if e1 != 0 { if e1 != 0 {
@@ -271,7 +219,7 @@ func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err erro
return return
} }
func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) { func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
var _p0 uint32 var _p0 uint32
if wait { if wait {
_p0 = 1 _p0 = 1

View File

@@ -37,7 +37,7 @@ func main() {
logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion) logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
logger.Debug.Println("Debug log enabled") logger.Debug.Println("Debug log enabled")
tun, err := tun.CreateTUN(interfaceName) tun, err := tun.CreateTUN(interfaceName, 0)
if err == nil { if err == nil {
realInterfaceName, err2 := tun.Name() realInterfaceName, err2 := tun.Name()
if err2 == nil { if err2 == nil {

View File

@@ -36,7 +36,7 @@ func TestRatelimiter(t *testing.T) {
for i := 0; i < packetsBurstable; i++ { for i := 0; i < packetsBurstable; i++ {
Add(RatelimiterResult{ Add(RatelimiterResult{
allowed: true, allowed: true,
text: "inital burst", text: "initial burst",
}) })
} }

View File

@@ -60,7 +60,13 @@ func (rw *RWCancel) ReadyRead() bool {
fdset := fdSet{} fdset := fdSet{}
fdset.set(rw.fd) fdset.set(rw.fd)
fdset.set(closeFd) fdset.set(closeFd)
err := unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil) var err error
for {
err = unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
if err == nil || !RetryAfterError(err) {
break
}
}
if err != nil { if err != nil {
return false return false
} }
@@ -75,7 +81,13 @@ func (rw *RWCancel) ReadyWrite() bool {
fdset := fdSet{} fdset := fdSet{}
fdset.set(rw.fd) fdset.set(rw.fd)
fdset.set(closeFd) fdset.set(closeFd)
err := unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil) var err error
for {
err = unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
if err == nil || !RetryAfterError(err) {
break
}
}
if err != nil { if err != nil {
return false return false
} }

View File

@@ -11,6 +11,7 @@ import (
"net" "net"
"os" "os"
"syscall" "syscall"
"time"
"unsafe" "unsafe"
"golang.org/x/net/ipv6" "golang.org/x/net/ipv6"
@@ -42,6 +43,22 @@ type NativeTun struct {
var sockaddrCtlSize uintptr = 32 var sockaddrCtlSize uintptr = 32
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
for i := 0; i < 20; i++ {
iface, err = net.InterfaceByIndex(index)
if err != nil {
if opErr, ok := err.(*net.OpError); ok {
if syscallErr, ok := opErr.Err.(*os.SyscallError); ok && syscallErr.Err == syscall.ENOMEM {
time.Sleep(time.Duration(i) * time.Second / 3)
continue
}
}
}
return iface, err
}
return nil, err
}
func (tun *NativeTun) routineRouteListener(tunIfindex int) { func (tun *NativeTun) routineRouteListener(tunIfindex int) {
var ( var (
statusUp bool statusUp bool
@@ -74,7 +91,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
continue continue
} }
iface, err := net.InterfaceByIndex(ifindex) iface, err := retryInterfaceByIndex(ifindex)
if err != nil { if err != nil {
tun.errors <- err tun.errors <- err
return return

View File

@@ -35,7 +35,7 @@ type NativeTun struct {
name string // name of interface name string // name of interface
errors chan error // async error handling errors chan error // async error handling
events chan Event // device related events events chan Event // device related events
nopi bool // the device was pased IFF_NO_PI nopi bool // the device was passed IFF_NO_PI
netlinkSock int netlinkSock int
netlinkCancel *rwcancel.RWCancel netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex hackListenerClosed sync.Mutex
@@ -85,7 +85,7 @@ func createNetlinkSocket() (int, error) {
} }
saddr := &unix.SockaddrNetlink{ saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK, Family: unix.AF_NETLINK,
Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))), Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
} }
err = unix.Bind(sock, saddr) err = unix.Bind(sock, saddr)
if err != nil { if err != nil {

View File

@@ -9,6 +9,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"os" "os"
"sync"
"sync/atomic" "sync/atomic"
"time" "time"
"unsafe" "unsafe"
@@ -35,11 +36,12 @@ type NativeTun struct {
wt *wintun.Interface wt *wintun.Interface
handle windows.Handle handle windows.Handle
close bool close bool
rings wintun.RingDescriptor
events chan Event events chan Event
errors chan error errors chan error
forcedMTU int forcedMTU int
rate rateJuggler rate rateJuggler
rings *wintun.RingDescriptor
writeLock sync.Mutex
} }
const WintunPool = wintun.Pool("WireGuard") const WintunPool = wintun.Pool("WireGuard")
@@ -54,15 +56,15 @@ func nanotime() int64
// CreateTUN creates a Wintun interface with the given name. Should a Wintun // CreateTUN creates a Wintun interface with the given name. Should a Wintun
// interface with the same name exist, it is reused. // interface with the same name exist, it is reused.
// //
func CreateTUN(ifname string) (Device, error) { func CreateTUN(ifname string, mtu int) (Device, error) {
return CreateTUNWithRequestedGUID(ifname, nil) return CreateTUNWithRequestedGUID(ifname, nil, mtu)
} }
// //
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and // CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
// a requested GUID. Should a Wintun interface with the same name exist, it is reused. // a requested GUID. Should a Wintun interface with the same name exist, it is reused.
// //
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Device, error) { func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
var err error var err error
var wt *wintun.Interface var wt *wintun.Interface
@@ -72,12 +74,17 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev
// If so, we delete it, in case it has weird residual configuration. // If so, we delete it, in case it has weird residual configuration.
_, err = wt.DeleteInterface() _, err = wt.DeleteInterface()
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to delete already existing Wintun interface: %v", err) return nil, fmt.Errorf("Error deleting already existing interface: %v", err)
} }
} }
wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID) wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID)
if err != nil { if err != nil {
return nil, fmt.Errorf("Unable to create Wintun interface: %v", err) return nil, fmt.Errorf("Error creating interface: %v", err)
}
forcedMTU := 1420
if mtu > 0 {
forcedMTU = mtu
} }
tun := &NativeTun{ tun := &NativeTun{
@@ -85,16 +92,16 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev
handle: windows.InvalidHandle, handle: windows.InvalidHandle,
events: make(chan Event, 10), events: make(chan Event, 10),
errors: make(chan error, 1), errors: make(chan error, 1),
forcedMTU: 1500, forcedMTU: forcedMTU,
} }
err = tun.rings.Init() tun.rings, err = wintun.NewRingDescriptor()
if err != nil { if err != nil {
tun.Close() tun.Close()
return nil, fmt.Errorf("Error creating events: %v", err) return nil, fmt.Errorf("Error creating events: %v", err)
} }
tun.handle, err = tun.wt.Register(&tun.rings) tun.handle, err = tun.wt.Register(tun.rings)
if err != nil { if err != nil {
tun.Close() tun.Close()
return nil, fmt.Errorf("Error registering rings: %v", err) return nil, fmt.Errorf("Error registering rings: %v", err)
@@ -214,6 +221,9 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
tun.rate.update(uint64(packetSize)) tun.rate.update(uint64(packetSize))
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize) alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
tun.writeLock.Lock()
defer tun.writeLock.Unlock()
buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head) buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
if buffHead >= wintun.PacketCapacity { if buffHead >= wintun.PacketCapacity {
return 0, os.ErrClosed return 0, os.ErrClosed
@@ -244,6 +254,11 @@ func (tun *NativeTun) LUID() uint64 {
return tun.wt.LUID() return tun.wt.LUID()
} }
// Version returns the version of the Wintun driver and NDIS system currently loaded.
func (tun *NativeTun) Version() (driverVersion string, ndisVersion string, err error) {
return tun.wt.Version()
}
func (rate *rateJuggler) update(packetLen uint64) { func (rate *rateJuggler) update(packetLen uint64) {
now := nanotime() now := nanotime()
total := atomic.AddUint64(&rate.nextByteCount, packetLen) total := atomic.AddUint64(&rate.nextByteCount, packetLen)

View File

@@ -5,4 +5,4 @@
package iphlpapi package iphlpapi
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go conversion_windows.go //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go conversion_windows.go

View File

@@ -16,7 +16,6 @@ import (
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
"golang.org/x/text/unicode/norm" "golang.org/x/text/unicode/norm"
"golang.zx2c4.com/wireguard/ipc/winpipe"
"golang.zx2c4.com/wireguard/tun/wintun/namespaceapi" "golang.zx2c4.com/wireguard/tun/wintun/namespaceapi"
) )
@@ -32,13 +31,13 @@ func initializeNamespace() error {
if hasInitializedNamespace { if hasInitializedNamespace {
return nil return nil
} }
sd, err := winpipe.SddlToSecurityDescriptor("O:SYD:P(A;;GA;;;SY)") sd, err := windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
if err != nil { if err != nil {
return fmt.Errorf("SddlToSecurityDescriptor failed: %v", err) return fmt.Errorf("SddlToSecurityDescriptor failed: %v", err)
} }
wintunObjectSecurityAttributes = &windows.SecurityAttributes{ wintunObjectSecurityAttributes = &windows.SecurityAttributes{
Length: uint32(len(sd)), Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
SecurityDescriptor: uintptr(unsafe.Pointer(&sd[0])), SecurityDescriptor: sd,
} }
sid, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid) sid, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
if err != nil { if err != nil {

View File

@@ -5,4 +5,4 @@
package namespaceapi package namespaceapi
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go namespaceapi_windows.go //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go namespaceapi_windows.go

View File

@@ -40,7 +40,7 @@ func (bd *BoundaryDescriptor) AddSid(requiredSid *windows.SID) error {
return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid) return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid)
} }
// PrivateNamespace represents a private namespace. Duh?! // PrivateNamespace represents a private namespace.
type PrivateNamespace windows.Handle type PrivateNamespace windows.Handle
// CreatePrivateNamespace creates a private namespace. // CreatePrivateNamespace creates a private namespace.

View File

@@ -5,4 +5,4 @@
package nci package nci
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go nci_windows.go //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go nci_windows.go

View File

@@ -5,4 +5,4 @@
package registry package registry
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zregistry_windows.go registry_windows.go //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zregistry_windows.go registry_windows.go

View File

@@ -6,6 +6,7 @@
package wintun package wintun
import ( import (
"runtime"
"unsafe" "unsafe"
"golang.org/x/sys/windows" "golang.org/x/sys/windows"
@@ -53,25 +54,44 @@ func PacketAlign(size uint32) uint32 {
return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1) return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
} }
func (descriptor *RingDescriptor) Init() (err error) { func NewRingDescriptor() (descriptor *RingDescriptor, err error) {
descriptor = new(RingDescriptor)
allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
if err != nil {
return
}
defer func() {
if err != nil {
descriptor.free()
descriptor = nil
}
}()
descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{})) descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
descriptor.Send.Ring = &Ring{} descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion))
descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil { if err != nil {
return return
} }
descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{})) descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
descriptor.Receive.Ring = &Ring{} descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{})))
descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil) descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
if err != nil { if err != nil {
windows.CloseHandle(descriptor.Send.TailMoved) windows.CloseHandle(descriptor.Send.TailMoved)
return return
} }
runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() })
return return
} }
func (descriptor *RingDescriptor) free() {
if descriptor.Send.Ring != nil {
windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE)
descriptor.Send.Ring = nil
descriptor.Receive.Ring = nil
}
}
func (descriptor *RingDescriptor) Close() { func (descriptor *RingDescriptor) Close() {
if descriptor.Send.TailMoved != 0 { if descriptor.Send.TailMoved != 0 {
windows.CloseHandle(descriptor.Send.TailMoved) windows.CloseHandle(descriptor.Send.TailMoved)

View File

@@ -5,4 +5,4 @@
package setupapi package setupapi
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsetupapi_windows.go setupapi_windows.go //go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsetupapi_windows.go setupapi_windows.go

View File

@@ -57,7 +57,7 @@ type DevInfoData struct {
_ uintptr _ uintptr
} }
// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass). // DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supersedes the functionality of SetupDiGetDeviceInfoListClass).
type DevInfoListDetailData struct { type DevInfoListDetailData struct {
size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const. size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const.
ClassGUID windows.GUID ClassGUID windows.GUID

View File

@@ -40,9 +40,9 @@ const (
) )
// makeWintun creates a Wintun interface handle and populates it from the device's registry key. // makeWintun creates a Wintun interface handle and populates it from the device's registry key.
func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData, pool Pool) (*Interface, error) { func makeWintun(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData, pool Pool) (*Interface, error) {
// Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key. // Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE) key, err := devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE)
if err != nil { if err != nil {
return nil, fmt.Errorf("Device-specific registry key open failed: %v", err) return nil, fmt.Errorf("Device-specific registry key open failed: %v", err)
} }
@@ -72,7 +72,7 @@ func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfo
return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err) return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err)
} }
instanceID, err := deviceInfoSet.DeviceInstanceID(deviceInfoData) instanceID, err := devInfo.DeviceInstanceID(devInfoData)
if err != nil { if err != nil {
return nil, fmt.Errorf("DeviceInstanceID failed: %v", err) return nil, fmt.Errorf("DeviceInstanceID failed: %v", err)
} }
@@ -109,11 +109,11 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
}() }()
// Create a list of network devices. // Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil { if err != nil {
return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err) return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err)
} }
defer devInfoList.Close() defer devInfo.Close()
// Windows requires each interface to have a different name. When // Windows requires each interface to have a different name. When
// enforcing this, Windows treats interface names case-insensitive. If an // enforcing this, Windows treats interface names case-insensitive. If an
@@ -123,7 +123,7 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
ifname = strings.ToLower(ifname) ifname = strings.ToLower(ifname)
for index := 0; ; index++ { for index := 0; ; index++ {
deviceData, err := devInfoList.EnumDeviceInfo(index) devInfoData, err := devInfo.EnumDeviceInfo(index)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -131,7 +131,16 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
continue continue
} }
wintun, err := makeWintun(devInfoList, deviceData, pool) // Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
if err != nil {
continue
}
if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
continue
}
wintun, err := makeWintun(devInfo, devInfoData, pool)
if err != nil { if err != nil {
continue continue
} }
@@ -145,14 +154,14 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
ifname3 := removeNumberedSuffix(ifname2) ifname3 := removeNumberedSuffix(ifname2)
if ifname == ifname2 || ifname == ifname3 { if ifname == ifname2 || ifname == ifname3 {
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
if err != nil { if err != nil {
return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err) return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
} }
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
for index := 0; ; index++ { for index := 0; ; index++ {
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, index) driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -161,13 +170,13 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
} }
// Get driver info details. // Get driver info details.
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData) driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil { if err != nil {
continue continue
} }
if driverDetailData.IsCompatible(hardwareID) { if driverDetailData.IsCompatible(hardwareID) {
isMember, err := pool.isMember(devInfoList, deviceData) isMember, err := pool.isMember(devInfo, devInfoData)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -206,12 +215,12 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
}() }()
// Create an empty device info set for network adapter device class. // Create an empty device info set for network adapter device class.
devInfoList, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "") devInfo, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err) err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err)
return return
} }
defer devInfoList.Close() defer devInfo.Close()
// Get the device class name from GUID. // Get the device class name from GUID.
className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "") className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "")
@@ -222,43 +231,43 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
// Create a new device info element and add it to the device info set. // Create a new device info element and add it to the device info set.
deviceTypeName := pool.deviceTypeName() deviceTypeName := pool.deviceTypeName()
deviceData, err := devInfoList.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID) devInfoData, err := devInfo.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID)
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err) err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err)
return return
} }
err = setQuietInstall(devInfoList, deviceData) err = setQuietInstall(devInfo, devInfoData)
if err != nil { if err != nil {
err = fmt.Errorf("Setting quiet installation failed: %v", err) err = fmt.Errorf("Setting quiet installation failed: %v", err)
return return
} }
// Set a device information element as the selected member of a device information set. // Set a device information element as the selected member of a device information set.
err = devInfoList.SetSelectedDevice(deviceData) err = devInfo.SetSelectedDevice(devInfoData)
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err) err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err)
return return
} }
// Set Plug&Play device hardware ID property. // Set Plug&Play device hardware ID property.
err = devInfoList.SetDeviceRegistryPropertyString(deviceData, setupapi.SPDRP_HARDWAREID, hardwareID) err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_HARDWAREID, hardwareID)
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err) err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err)
return return
} }
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err) err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
return return
} }
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
driverDate := windows.Filetime{} driverDate := windows.Filetime{}
driverVersion := uint64(0) driverVersion := uint64(0)
for index := 0; ; index++ { // TODO: This loop takes ~600ms for index := 0; ; index++ { // TODO: This loop takes ~600ms
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, index) driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -268,13 +277,13 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
// Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match. // Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match.
if driverData.IsNewer(driverDate, driverVersion) { if driverData.IsNewer(driverDate, driverVersion) {
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData) driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil { if err != nil {
continue continue
} }
if driverDetailData.IsCompatible(hardwareID) { if driverDetailData.IsCompatible(hardwareID) {
err := devInfoList.SetSelectedDriver(deviceData, driverData) err := devInfo.SetSelectedDriver(devInfoData, driverData)
if err != nil { if err != nil {
continue continue
} }
@@ -299,10 +308,10 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
} }
// Set class installer parameters for DIF_REMOVE. // Set class installer parameters for DIF_REMOVE.
if devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil { if devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil {
// Call appropriate class installer. // Call appropriate class installer.
if devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) == nil { if devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData) == nil {
rebootRequired = rebootRequired || checkReboot(devInfoList, deviceData) rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
} }
} }
@@ -311,14 +320,14 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
}() }()
// Call appropriate class installer. // Call appropriate class installer.
err = devInfoList.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, deviceData) err = devInfo.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, devInfoData)
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err) err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err)
return return
} }
// Register device co-installers if any. (Ignore errors) // Register device co-installers if any. (Ignore errors)
devInfoList.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, deviceData) devInfo.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, devInfoData)
var netDevRegKey registry.Key var netDevRegKey registry.Key
const pollTimeout = time.Millisecond * 50 const pollTimeout = time.Millisecond * 50
@@ -326,7 +335,7 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
if i != 0 { if i != 0 {
time.Sleep(pollTimeout) time.Sleep(pollTimeout)
} }
netDevRegKey, err = devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY) netDevRegKey, err = devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY)
if err == nil { if err == nil {
break break
} }
@@ -345,17 +354,17 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
} }
// Install interfaces if any. (Ignore errors) // Install interfaces if any. (Ignore errors)
devInfoList.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, deviceData) devInfo.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, devInfoData)
// Install the device. // Install the device.
err = devInfoList.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, deviceData) err = devInfo.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, devInfoData)
if err != nil { if err != nil {
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err) err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err)
return return
} }
rebootRequired = checkReboot(devInfoList, deviceData) rebootRequired = checkReboot(devInfo, devInfoData)
err = devInfoList.SetDeviceRegistryPropertyString(deviceData, setupapi.SPDRP_DEVICEDESC, deviceTypeName) err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_DEVICEDESC, deviceTypeName)
if err != nil { if err != nil {
err = fmt.Errorf("SetDeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err) err = fmt.Errorf("SetDeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
return return
@@ -381,7 +390,7 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
} }
// Get network interface. // Get network interface.
wintun, err = makeWintun(devInfoList, deviceData, pool) wintun, err = makeWintun(devInfo, devInfoData, pool)
if err != nil { if err != nil {
err = fmt.Errorf("makeWintun failed: %v", err) err = fmt.Errorf("makeWintun failed: %v", err)
return return
@@ -435,14 +444,14 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
// if the interface was not found. It returns a bool indicating whether // if the interface was not found. It returns a bool indicating whether
// a reboot is required. // a reboot is required.
func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) { func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) {
devInfoList, deviceData, err := wintun.deviceData() devInfo, devInfoData, err := wintun.devInfoData()
if err == windows.ERROR_OBJECT_NOT_FOUND { if err == windows.ERROR_OBJECT_NOT_FOUND {
return false, nil return false, nil
} }
if err != nil { if err != nil {
return false, err return false, err
} }
defer devInfoList.Close() defer devInfo.Close()
// Remove the device. // Remove the device.
removeDeviceParams := setupapi.RemoveDeviceParams{ removeDeviceParams := setupapi.RemoveDeviceParams{
@@ -451,18 +460,18 @@ func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) {
} }
// Set class installer parameters for DIF_REMOVE. // Set class installer parameters for DIF_REMOVE.
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
if err != nil { if err != nil {
return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err) return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err)
} }
// Call appropriate class installer. // Call appropriate class installer.
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
if err != nil { if err != nil {
return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err) return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err)
} }
return checkReboot(devInfoList, deviceData), nil return checkReboot(devInfo, devInfoData), nil
} }
// DeleteMatchingInterfaces deletes all Wintun interfaces, which match // DeleteMatchingInterfaces deletes all Wintun interfaces, which match
@@ -479,14 +488,14 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
windows.CloseHandle(mutex) windows.CloseHandle(mutex)
}() }()
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil { if err != nil {
return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())} return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())}
} }
defer devInfoList.Close() defer devInfo.Close()
for i := 0; ; i++ { for i := 0; ; i++ {
deviceData, err := devInfoList.EnumDeviceInfo(i) devInfoData, err := devInfo.EnumDeviceInfo(i)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -494,22 +503,31 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
continue continue
} }
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) // Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
if err != nil { if err != nil {
continue continue
} }
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
continue
}
err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
if err != nil {
continue
}
defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
isWintun := false isWintun := false
for j := 0; ; j++ { for j := 0; ; j++ {
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, j) driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, j)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
} }
continue continue
} }
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData) driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
if err != nil { if err != nil {
continue continue
} }
@@ -522,7 +540,7 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
continue continue
} }
isMember, err := pool.isMember(devInfoList, deviceData) isMember, err := pool.isMember(devInfo, devInfoData)
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
continue continue
@@ -531,7 +549,7 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
continue continue
} }
wintun, err := makeWintun(devInfoList, deviceData, pool) wintun, err := makeWintun(devInfo, devInfoData, pool)
if err != nil { if err != nil {
errors = append(errors, fmt.Errorf("Unable to make Wintun interface object: %v", err)) errors = append(errors, fmt.Errorf("Unable to make Wintun interface object: %v", err))
continue continue
@@ -540,41 +558,41 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
continue continue
} }
err = setQuietInstall(devInfoList, deviceData) err = setQuietInstall(devInfo, devInfoData)
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
continue continue
} }
inst := deviceData.DevInst inst := devInfoData.DevInst
removeDeviceParams := setupapi.RemoveDeviceParams{ removeDeviceParams := setupapi.RemoveDeviceParams{
ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE), ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
Scope: setupapi.DI_REMOVEDEVICE_GLOBAL, Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
} }
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
continue continue
} }
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
if err != nil { if err != nil {
errors = append(errors, err) errors = append(errors, err)
continue continue
} }
rebootRequired = rebootRequired || checkReboot(devInfoList, deviceData) rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
deviceInstancesDeleted = append(deviceInstancesDeleted, inst) deviceInstancesDeleted = append(deviceInstancesDeleted, inst)
} }
return return
} }
// isMember checks if SPDRP_DEVICEDESC or SPDRP_FRIENDLYNAME match device type name. // isMember checks if SPDRP_DEVICEDESC or SPDRP_FRIENDLYNAME match device type name.
func (pool Pool) isMember(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) (bool, error) { func (pool Pool) isMember(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) (bool, error) {
deviceDescVal, err := deviceInfoSet.DeviceRegistryProperty(deviceInfoData, setupapi.SPDRP_DEVICEDESC) deviceDescVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_DEVICEDESC)
if err != nil { if err != nil {
return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err) return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
} }
deviceDesc, _ := deviceDescVal.(string) deviceDesc, _ := deviceDescVal.(string)
friendlyNameVal, err := deviceInfoSet.DeviceRegistryProperty(deviceInfoData, setupapi.SPDRP_FRIENDLYNAME) friendlyNameVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_FRIENDLYNAME)
if err != nil { if err != nil {
return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err) return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err)
} }
@@ -585,8 +603,8 @@ func (pool Pool) isMember(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupa
} }
// checkReboot checks device install parameters if a system reboot is required. // checkReboot checks device install parameters if a system reboot is required.
func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) bool { func checkReboot(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) bool {
devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData) devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
if err != nil { if err != nil {
return false return false
} }
@@ -595,14 +613,14 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
} }
// setQuietInstall sets device install parameters for a quiet installation // setQuietInstall sets device install parameters for a quiet installation
func setQuietInstall(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) error { func setQuietInstall(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) error {
devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData) devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
if err != nil { if err != nil {
return err return err
} }
devInstallParams.Flags |= setupapi.DI_QUIETINSTALL devInstallParams.Flags |= setupapi.DI_QUIETINSTALL
return deviceInfoSet.SetDeviceInstallParams(deviceInfoData, devInstallParams) return devInfo.SetDeviceInstallParams(devInfoData, devInstallParams)
} }
// deviceTypeName returns pool-specific device type name. // deviceTypeName returns pool-specific device type name.
@@ -671,11 +689,39 @@ func (wintun *Interface) tcpipAdapterRegKeyName() string {
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", wintun.cfgInstanceID) return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", wintun.cfgInstanceID)
} }
// deviceRegKeyName returns the device-level registry key name // deviceRegKeyName returns the device-level registry key name.
func (wintun *Interface) deviceRegKeyName() string { func (wintun *Interface) deviceRegKeyName() string {
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Enum\\%v", wintun.devInstanceID) return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Enum\\%v", wintun.devInstanceID)
} }
// Version returns the version of the Wintun driver and NDIS system currently loaded.
func (wintun *Interface) Version() (driverVersion string, ndisVersion string, err error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, "SYSTEM\\CurrentControlSet\\Services\\Wintun", registry.QUERY_VALUE)
if err != nil {
return
}
defer key.Close()
driverMajor, _, err := key.GetIntegerValue("DriverMajorVersion")
if err != nil {
return
}
driverMinor, _, err := key.GetIntegerValue("DriverMinorVersion")
if err != nil {
return
}
ndisMajor, _, err := key.GetIntegerValue("NdisMajorVersion")
if err != nil {
return
}
ndisMinor, _, err := key.GetIntegerValue("NdisMinorVersion")
if err != nil {
return
}
driverVersion = fmt.Sprintf("%d.%d", driverMajor, driverMinor)
ndisVersion = fmt.Sprintf("%d.%d", ndisMajor, ndisMinor)
return
}
// tcpipInterfaceRegKeyName returns the interface-specific TCP/IP network registry key name. // tcpipInterfaceRegKeyName returns the interface-specific TCP/IP network registry key name.
func (wintun *Interface) tcpipInterfaceRegKeyName() (path string, err error) { func (wintun *Interface) tcpipInterfaceRegKeyName() (path string, err error) {
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE) key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE)
@@ -693,18 +739,18 @@ func (wintun *Interface) tcpipInterfaceRegKeyName() (path string, err error) {
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil
} }
// deviceData returns TUN device info list handle and interface device info // devInfoData returns TUN device info list handle and interface device info
// data. The device info list handle must be closed after use. In case the // data. The device info list handle must be closed after use. In case the
// device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned. // device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned.
func (wintun *Interface) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData, error) { func (wintun *Interface) devInfoData() (setupapi.DevInfo, *setupapi.DevInfoData, error) {
// Create a list of network devices. // Create a list of network devices.
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "") devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
if err != nil { if err != nil {
return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error()) return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())
} }
for index := 0; ; index++ { for index := 0; ; index++ {
deviceData, err := devInfoList.EnumDeviceInfo(index) devInfoData, err := devInfo.EnumDeviceInfo(index)
if err != nil { if err != nil {
if err == windows.ERROR_NO_MORE_ITEMS { if err == windows.ERROR_NO_MORE_ITEMS {
break break
@@ -714,22 +760,22 @@ func (wintun *Interface) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData,
// Get interface ID. // Get interface ID.
// TODO: Store some ID in the Wintun object such that this call isn't required. // TODO: Store some ID in the Wintun object such that this call isn't required.
wintun2, err := makeWintun(devInfoList, deviceData, wintun.pool) wintun2, err := makeWintun(devInfo, devInfoData, wintun.pool)
if err != nil { if err != nil {
continue continue
} }
if wintun.cfgInstanceID == wintun2.cfgInstanceID { if wintun.cfgInstanceID == wintun2.cfgInstanceID {
err = setQuietInstall(devInfoList, deviceData) err = setQuietInstall(devInfo, devInfoData)
if err != nil { if err != nil {
devInfoList.Close() devInfo.Close()
return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err) return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err)
} }
return devInfoList, deviceData, nil return devInfo, devInfoData, nil
} }
} }
devInfoList.Close() devInfo.Close()
return 0, nil, windows.ERROR_OBJECT_NOT_FOUND return 0, nil, windows.ERROR_OBJECT_NOT_FOUND
} }