Compare commits
25 Commits
0.0.201909
...
0.0.202001
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
05b03c6750 | ||
|
|
caebdfe9d0 | ||
|
|
4fa2ea6a2d | ||
|
|
89dd065e53 | ||
|
|
ddfad453cf | ||
|
|
2b242f9393 | ||
|
|
4cdf805b29 | ||
|
|
f7d0edd2ec | ||
|
|
ffffbbcc8a | ||
|
|
47b02c618b | ||
|
|
fd23c66fcd | ||
|
|
ae492d1b35 | ||
|
|
95fbfccf60 | ||
|
|
c85e4a410f | ||
|
|
1b6c8ddbe8 | ||
|
|
0abb6b668c | ||
|
|
540d01e54a | ||
|
|
f2ea85e9f9 | ||
|
|
222f0f8000 | ||
|
|
1f146a5e7a | ||
|
|
f2501aa6c8 | ||
|
|
cb8d01f58a | ||
|
|
01f8ef4e84 | ||
|
|
70f6c42556 | ||
|
|
bb0b2514c0 |
2
Makefile
2
Makefile
@@ -10,7 +10,7 @@ MAKEFLAGS += --no-print-directory
|
|||||||
generate-version-and-build:
|
generate-version-and-build:
|
||||||
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
@export GIT_CEILING_DIRECTORIES="$(realpath $(CURDIR)/..)" && \
|
||||||
tag="$$(git describe --dirty 2>/dev/null)" && \
|
tag="$$(git describe --dirty 2>/dev/null)" && \
|
||||||
ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$$tag")" && \
|
ver="$$(printf 'package device\nconst WireGuardGoVersion = "%s"\n' "$${tag#v}")" && \
|
||||||
[ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \
|
[ "$$(cat device/version.go 2>/dev/null)" != "$$ver" ] && \
|
||||||
echo "$$ver" > device/version.go && \
|
echo "$$ver" > device/version.go && \
|
||||||
git update-index --assume-unchanged device/version.go || true
|
git update-index --assume-unchanged device/version.go || true
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ To run wireguard-go without forking to the background, pass `-f` or `--foregroun
|
|||||||
$ wireguard-go -f wg0
|
$ wireguard-go -f wg0
|
||||||
```
|
```
|
||||||
|
|
||||||
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/WireGuard/about/src/tools/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
When an interface is running, you may use [`wg(8)`](https://git.zx2c4.com/wireguard-tools/about/src/man/wg.8) to configure it, as well as the usual `ip(8)` and `ifconfig(8)` commands.
|
||||||
|
|
||||||
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ const (
|
|||||||
sockoptIPV6_UNICAST_IF = 31
|
sockoptIPV6_UNICAST_IF = 31
|
||||||
)
|
)
|
||||||
|
|
||||||
func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
|
func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||||
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
|
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
|
||||||
bytes := make([]byte, 4)
|
bytes := make([]byte, 4)
|
||||||
binary.BigEndian.PutUint32(bytes, interfaceIndex)
|
binary.BigEndian.PutUint32(bytes, interfaceIndex)
|
||||||
@@ -41,10 +41,11 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
device.net.bind.(*nativeBind).blackhole4 = blackhole
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
|
func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||||
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
|
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -58,5 +59,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
device.net.bind.(*nativeBind).blackhole6 = blackhole
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,8 @@ import (
|
|||||||
type nativeBind struct {
|
type nativeBind struct {
|
||||||
ipv4 *net.UDPConn
|
ipv4 *net.UDPConn
|
||||||
ipv6 *net.UDPConn
|
ipv6 *net.UDPConn
|
||||||
|
blackhole4 bool
|
||||||
|
blackhole6 bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type NativeEndpoint net.UDPAddr
|
type NativeEndpoint net.UDPAddr
|
||||||
@@ -159,11 +161,17 @@ func (bind *nativeBind) Send(buff []byte, endpoint Endpoint) error {
|
|||||||
if bind.ipv4 == nil {
|
if bind.ipv4 == nil {
|
||||||
return syscall.EAFNOSUPPORT
|
return syscall.EAFNOSUPPORT
|
||||||
}
|
}
|
||||||
|
if bind.blackhole4 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
||||||
} else {
|
} else {
|
||||||
if bind.ipv6 == nil {
|
if bind.ipv6 == nil {
|
||||||
return syscall.EAFNOSUPPORT
|
return syscall.EAFNOSUPPORT
|
||||||
}
|
}
|
||||||
|
if bind.blackhole6 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
|
||||||
}
|
}
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
* This implements userspace semantics of "sticky sockets", modeled after
|
* This implements userspace semantics of "sticky sockets", modeled after
|
||||||
* WireGuard's kernelspace implementation. This is more or less a straight port
|
* WireGuard's kernelspace implementation. This is more or less a straight port
|
||||||
* of the sticky-sockets.c example code:
|
* of the sticky-sockets.c example code:
|
||||||
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
|
* https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c
|
||||||
*
|
*
|
||||||
* Currently there is no way to achieve this within the net package:
|
* Currently there is no way to achieve this within the net package:
|
||||||
* See e.g. https://github.com/golang/go/issues/17930
|
* See e.g. https://github.com/golang/go/issues/17930
|
||||||
@@ -43,6 +43,7 @@ type IPv6Source struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type NativeEndpoint struct {
|
type NativeEndpoint struct {
|
||||||
|
sync.Mutex
|
||||||
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
|
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
|
||||||
src [unsafe.Sizeof(IPv6Source{})]byte
|
src [unsafe.Sizeof(IPv6Source{})]byte
|
||||||
isV6 bool
|
isV6 bool
|
||||||
@@ -145,7 +146,7 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
|
|||||||
|
|
||||||
go bind.routineRouteListener(device)
|
go bind.routineRouteListener(device)
|
||||||
|
|
||||||
// attempt ipv6 bind, update port if succesful
|
// attempt ipv6 bind, update port if successful
|
||||||
|
|
||||||
bind.sock6, newPort, err = create6(port)
|
bind.sock6, newPort, err = create6(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -157,7 +158,7 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
|
|||||||
port = newPort
|
port = newPort
|
||||||
}
|
}
|
||||||
|
|
||||||
// attempt ipv4 bind, update port if succesful
|
// attempt ipv4 bind, update port if successful
|
||||||
|
|
||||||
bind.sock4, newPort, err = create4(port)
|
bind.sock4, newPort, err = create4(port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -482,7 +483,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
end.Lock()
|
||||||
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
||||||
|
end.Unlock()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -493,7 +496,9 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
if err == unix.EINVAL {
|
if err == unix.EINVAL {
|
||||||
end.ClearSrc()
|
end.ClearSrc()
|
||||||
cmsg.pktinfo = unix.Inet4Pktinfo{}
|
cmsg.pktinfo = unix.Inet4Pktinfo{}
|
||||||
|
end.Lock()
|
||||||
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
|
||||||
|
end.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@@ -522,7 +527,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
cmsg.pktinfo.Ifindex = 0
|
cmsg.pktinfo.Ifindex = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
end.Lock()
|
||||||
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
||||||
|
end.Unlock()
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -533,7 +540,9 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
if err == unix.EINVAL {
|
if err == unix.EINVAL {
|
||||||
end.ClearSrc()
|
end.ClearSrc()
|
||||||
cmsg.pktinfo = unix.Inet6Pktinfo{}
|
cmsg.pktinfo = unix.Inet6Pktinfo{}
|
||||||
|
end.Lock()
|
||||||
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
|
||||||
|
end.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
return err
|
||||||
@@ -541,7 +550,7 @@ func send6(sock int, end *NativeEndpoint, buff []byte) error {
|
|||||||
|
|
||||||
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||||
|
|
||||||
// contruct message header
|
// construct message header
|
||||||
|
|
||||||
var cmsg struct {
|
var cmsg struct {
|
||||||
cmsghdr unix.Cmsghdr
|
cmsghdr unix.Cmsghdr
|
||||||
@@ -573,7 +582,7 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
|||||||
|
|
||||||
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||||
|
|
||||||
// contruct message header
|
// construct message header
|
||||||
|
|
||||||
var cmsg struct {
|
var cmsg struct {
|
||||||
cmsghdr unix.Cmsghdr
|
cmsghdr unix.Cmsghdr
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
/* Specification constants */
|
/* Specification constants */
|
||||||
|
|
||||||
const (
|
const (
|
||||||
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
|
RekeyAfterMessages = (1 << 60)
|
||||||
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
|
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
|
||||||
RekeyAfterTime = time.Second * 120
|
RekeyAfterTime = time.Second * 120
|
||||||
RekeyAttemptTime = time.Second * 90
|
RekeyAttemptTime = time.Second * 90
|
||||||
|
|||||||
@@ -5,54 +5,212 @@
|
|||||||
|
|
||||||
package device
|
package device
|
||||||
|
|
||||||
/* Create two device instances and simulate full WireGuard interaction
|
|
||||||
* without network dependencies
|
|
||||||
*/
|
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/binary"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestDevice(t *testing.T) {
|
func TestTwoDevicePing(t *testing.T) {
|
||||||
|
// TODO(crawshaw): pick unused ports on localhost
|
||||||
// prepare tun devices for generating traffic
|
cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58
|
||||||
|
listen_port=53511
|
||||||
tun1 := newDummyTUN("tun1")
|
replace_peers=true
|
||||||
tun2 := newDummyTUN("tun2")
|
public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725
|
||||||
|
protocol_version=1
|
||||||
_ = tun1
|
replace_allowed_ips=true
|
||||||
_ = tun2
|
allowed_ip=1.0.0.2/32
|
||||||
|
endpoint=127.0.0.1:53512`
|
||||||
// prepare endpoints
|
tun1 := NewChannelTUN()
|
||||||
|
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
|
||||||
end1, err := CreateDummyEndpoint()
|
dev1.Up()
|
||||||
if err != nil {
|
defer dev1.Close()
|
||||||
t.Error("failed to create endpoint:", err.Error())
|
if err := dev1.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg1))); err != nil {
|
||||||
}
|
|
||||||
|
|
||||||
end2, err := CreateDummyEndpoint()
|
|
||||||
if err != nil {
|
|
||||||
t.Error("failed to create endpoint:", err.Error())
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = end1
|
|
||||||
_ = end2
|
|
||||||
|
|
||||||
// create binds
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func randDevice(t *testing.T) *Device {
|
|
||||||
sk, err := newPrivateKey()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
tun := newDummyTUN("dummy")
|
|
||||||
logger := NewLogger(LogLevelError, "")
|
cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768
|
||||||
device := NewDevice(tun, logger)
|
listen_port=53512
|
||||||
device.SetPrivateKey(sk)
|
replace_peers=true
|
||||||
return device
|
public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427
|
||||||
|
protocol_version=1
|
||||||
|
replace_allowed_ips=true
|
||||||
|
allowed_ip=1.0.0.1/32
|
||||||
|
endpoint=127.0.0.1:53511`
|
||||||
|
tun2 := NewChannelTUN()
|
||||||
|
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
|
||||||
|
dev2.Up()
|
||||||
|
defer dev2.Close()
|
||||||
|
if err := dev2.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg2))); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||||
|
msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
|
||||||
|
tun2.Outbound <- msg2to1
|
||||||
|
select {
|
||||||
|
case msgRecv := <-tun1.Inbound:
|
||||||
|
if !bytes.Equal(msg2to1, msgRecv) {
|
||||||
|
t.Error("ping did not transit correctly")
|
||||||
|
}
|
||||||
|
case <-time.After(300 * time.Millisecond):
|
||||||
|
t.Error("ping did not transit")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||||
|
msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
|
||||||
|
tun1.Outbound <- msg1to2
|
||||||
|
select {
|
||||||
|
case msgRecv := <-tun2.Inbound:
|
||||||
|
if !bytes.Equal(msg1to2, msgRecv) {
|
||||||
|
t.Error("return ping did not transit correctly")
|
||||||
|
}
|
||||||
|
case <-time.After(300 * time.Millisecond):
|
||||||
|
t.Error("return ping did not transit")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func ping(dst, src net.IP) []byte {
|
||||||
|
localPort := uint16(1337)
|
||||||
|
seq := uint16(0)
|
||||||
|
|
||||||
|
payload := make([]byte, 4)
|
||||||
|
binary.BigEndian.PutUint16(payload[0:], localPort)
|
||||||
|
binary.BigEndian.PutUint16(payload[2:], seq)
|
||||||
|
|
||||||
|
return genICMPv4(payload, dst, src)
|
||||||
|
}
|
||||||
|
|
||||||
|
// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
|
||||||
|
func checksum(buf []byte, initial uint16) uint16 {
|
||||||
|
v := uint32(initial)
|
||||||
|
for i := 0; i < len(buf)-1; i += 2 {
|
||||||
|
v += uint32(binary.BigEndian.Uint16(buf[i:]))
|
||||||
|
}
|
||||||
|
if len(buf)%2 == 1 {
|
||||||
|
v += uint32(buf[len(buf)-1]) << 8
|
||||||
|
}
|
||||||
|
for v > 0xffff {
|
||||||
|
v = (v >> 16) + (v & 0xffff)
|
||||||
|
}
|
||||||
|
return ^uint16(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
func genICMPv4(payload []byte, dst, src net.IP) []byte {
|
||||||
|
const (
|
||||||
|
icmpv4ProtocolNumber = 1
|
||||||
|
icmpv4Echo = 8
|
||||||
|
icmpv4ChecksumOffset = 2
|
||||||
|
icmpv4Size = 8
|
||||||
|
ipv4Size = 20
|
||||||
|
ipv4TotalLenOffset = 2
|
||||||
|
ipv4ChecksumOffset = 10
|
||||||
|
ttl = 65
|
||||||
|
)
|
||||||
|
|
||||||
|
hdr := make([]byte, ipv4Size+icmpv4Size)
|
||||||
|
|
||||||
|
ip := hdr[0:ipv4Size]
|
||||||
|
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
|
||||||
|
|
||||||
|
// https://tools.ietf.org/html/rfc792
|
||||||
|
icmpv4[0] = icmpv4Echo // type
|
||||||
|
icmpv4[1] = 0 // code
|
||||||
|
chksum := ^checksum(icmpv4, checksum(payload, 0))
|
||||||
|
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
|
||||||
|
|
||||||
|
// https://tools.ietf.org/html/rfc760 section 3.1
|
||||||
|
length := uint16(len(hdr) + len(payload))
|
||||||
|
ip[0] = (4 << 4) | (ipv4Size / 4)
|
||||||
|
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
|
||||||
|
ip[8] = ttl
|
||||||
|
ip[9] = icmpv4ProtocolNumber
|
||||||
|
copy(ip[12:], src.To4())
|
||||||
|
copy(ip[16:], dst.To4())
|
||||||
|
chksum = ^checksum(ip[:], 0)
|
||||||
|
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
|
||||||
|
|
||||||
|
var v []byte
|
||||||
|
v = append(v, hdr...)
|
||||||
|
v = append(v, payload...)
|
||||||
|
return []byte(v)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO(crawshaw): find a reusable home for this. package devicetest?
|
||||||
|
type ChannelTUN struct {
|
||||||
|
Inbound chan []byte // incoming packets, closed on TUN close
|
||||||
|
Outbound chan []byte // outbound packets, blocks forever on TUN close
|
||||||
|
|
||||||
|
closed chan struct{}
|
||||||
|
events chan tun.Event
|
||||||
|
tun chTun
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewChannelTUN() *ChannelTUN {
|
||||||
|
c := &ChannelTUN{
|
||||||
|
Inbound: make(chan []byte),
|
||||||
|
Outbound: make(chan []byte),
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
events: make(chan tun.Event, 1),
|
||||||
|
}
|
||||||
|
c.tun.c = c
|
||||||
|
c.events <- tun.EventUp
|
||||||
|
return c
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *ChannelTUN) TUN() tun.Device {
|
||||||
|
return &c.tun
|
||||||
|
}
|
||||||
|
|
||||||
|
type chTun struct {
|
||||||
|
c *ChannelTUN
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *chTun) File() *os.File { return nil }
|
||||||
|
|
||||||
|
func (t *chTun) Read(data []byte, offset int) (int, error) {
|
||||||
|
select {
|
||||||
|
case <-t.c.closed:
|
||||||
|
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
|
||||||
|
case msg := <-t.c.Outbound:
|
||||||
|
return copy(data[offset:], msg), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write is called by the wireguard device to deliver a packet for routing.
|
||||||
|
func (t *chTun) Write(data []byte, offset int) (int, error) {
|
||||||
|
if offset == -1 {
|
||||||
|
close(t.c.closed)
|
||||||
|
close(t.c.events)
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
msg := make([]byte, len(data)-offset)
|
||||||
|
copy(msg, data[offset:])
|
||||||
|
select {
|
||||||
|
case <-t.c.closed:
|
||||||
|
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
|
||||||
|
case t.c.Inbound <- msg:
|
||||||
|
return len(data) - offset, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *chTun) Flush() error { return nil }
|
||||||
|
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
|
||||||
|
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
|
||||||
|
func (t *chTun) Events() chan tun.Event { return t.c.events }
|
||||||
|
func (t *chTun) Close() error {
|
||||||
|
t.Write(nil, -1)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func assertNil(t *testing.T, err error) {
|
func assertNil(t *testing.T, err error) {
|
||||||
@@ -66,3 +224,15 @@ func assertEqual(t *testing.T, a, b []byte) {
|
|||||||
t.Fatal(a, "!=", b)
|
t.Fatal(a, "!=", b)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func randDevice(t *testing.T) *Device {
|
||||||
|
sk, err := newPrivateKey()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tun := newDummyTUN("dummy")
|
||||||
|
logger := NewLogger(LogLevelError, "")
|
||||||
|
device := NewDevice(tun, logger)
|
||||||
|
device.SetPrivateKey(sk)
|
||||||
|
return device
|
||||||
|
}
|
||||||
|
|||||||
@@ -39,13 +39,13 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
MessageInitiationSize = 148 // size of handshake initation message
|
MessageInitiationSize = 148 // size of handshake initiation message
|
||||||
MessageResponseSize = 92 // size of response message
|
MessageResponseSize = 92 // size of response message
|
||||||
MessageCookieReplySize = 64 // size of cookie reply message
|
MessageCookieReplySize = 64 // size of cookie reply message
|
||||||
MessageTransportHeaderSize = 16 // size of data preceeding content in transport message
|
MessageTransportHeaderSize = 16 // size of data preceding content in transport message
|
||||||
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
|
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
|
||||||
MessageKeepaliveSize = MessageTransportSize // size of keepalive
|
MessageKeepaliveSize = MessageTransportSize // size of keepalive
|
||||||
MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message
|
MessageHandshakeSize = MessageInitiationSize // size of largest handshake related message
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -315,8 +315,13 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
|||||||
handshake.chainKey = chainKey
|
handshake.chainKey = chainKey
|
||||||
handshake.remoteIndex = msg.Sender
|
handshake.remoteIndex = msg.Sender
|
||||||
handshake.remoteEphemeral = msg.Ephemeral
|
handshake.remoteEphemeral = msg.Ephemeral
|
||||||
|
if timestamp.After(handshake.lastTimestamp) {
|
||||||
handshake.lastTimestamp = timestamp
|
handshake.lastTimestamp = timestamp
|
||||||
handshake.lastInitiationConsumption = time.Now()
|
}
|
||||||
|
now := time.Now()
|
||||||
|
if now.After(handshake.lastInitiationConsumption) {
|
||||||
|
handshake.lastInitiationConsumption = now
|
||||||
|
}
|
||||||
handshake.state = HandshakeInitiationConsumed
|
handshake.state = HandshakeInitiationConsumed
|
||||||
|
|
||||||
handshake.mutex.Unlock()
|
handshake.mutex.Unlock()
|
||||||
|
|||||||
@@ -220,10 +220,7 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement)
|
|||||||
writer := bytes.NewBuffer(buff[:0])
|
writer := bytes.NewBuffer(buff[:0])
|
||||||
binary.Write(writer, binary.LittleEndian, reply)
|
binary.Write(writer, binary.LittleEndian, reply)
|
||||||
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
|
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
|
||||||
if err != nil {
|
return nil
|
||||||
device.log.Error.Println("Failed to send cookie reply:", err)
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (peer *Peer) keepKeyFreshSending() {
|
func (peer *Peer) keepKeyFreshSending() {
|
||||||
|
|||||||
@@ -113,6 +113,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
|||||||
var peer *Peer
|
var peer *Peer
|
||||||
|
|
||||||
dummy := false
|
dummy := false
|
||||||
|
createdNewPeer := false
|
||||||
deviceConfig := true
|
deviceConfig := true
|
||||||
|
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
@@ -237,7 +238,8 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
|||||||
peer = device.LookupPeer(publicKey)
|
peer = device.LookupPeer(publicKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
if peer == nil {
|
createdNewPeer = peer == nil
|
||||||
|
if createdNewPeer {
|
||||||
peer, err = device.NewPeer(publicKey)
|
peer, err = device.NewPeer(publicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logError.Println("Failed to create new peer:", err)
|
logError.Println("Failed to create new peer:", err)
|
||||||
@@ -251,6 +253,20 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case "update_only":
|
||||||
|
|
||||||
|
// allow disabling of creation
|
||||||
|
|
||||||
|
if value != "true" {
|
||||||
|
logError.Println("Failed to set update only, invalid value:", value)
|
||||||
|
return &IPCError{ipc.IpcErrorInvalid}
|
||||||
|
}
|
||||||
|
if createdNewPeer && !dummy {
|
||||||
|
device.RemovePeer(peer.handshake.remoteStatic)
|
||||||
|
peer = &Peer{}
|
||||||
|
dummy = true
|
||||||
|
}
|
||||||
|
|
||||||
case "remove":
|
case "remove":
|
||||||
|
|
||||||
// remove currently selected peer from device
|
// remove currently selected peer from device
|
||||||
|
|||||||
@@ -1,3 +1,3 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
const WireGuardGoVersion = "0.0.20190908"
|
const WireGuardGoVersion = "0.0.20200121"
|
||||||
|
|||||||
6
go.mod
6
go.mod
@@ -3,8 +3,8 @@ module golang.zx2c4.com/wireguard
|
|||||||
go 1.12
|
go 1.12
|
||||||
|
|
||||||
require (
|
require (
|
||||||
golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472
|
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc
|
||||||
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297
|
golang.org/x/net v0.0.0-20191003171128-d98b1b443823
|
||||||
golang.org/x/sys v0.0.0-20190830023255-19e00faab6ad
|
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c
|
||||||
golang.org/x/text v0.3.2
|
golang.org/x/text v0.3.2
|
||||||
)
|
)
|
||||||
|
|||||||
12
go.sum
12
go.sum
@@ -1,13 +1,13 @@
|
|||||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||||
golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472 h1:Gv7RPwsi3eZ2Fgewe3CBsuOebPwO27PoXzRpJPsvSSM=
|
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4=
|
||||||
golang.org/x/crypto v0.0.0-20190829043050-9756ffdc2472/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||||
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297 h1:k7pJ2yAPLPgbskkFdhRCsA77k2fySZ1zf2zCjvQCiIM=
|
golang.org/x/net v0.0.0-20191003171128-d98b1b443823 h1:Ypyv6BNJh07T1pUSrehkLemqPKXhus2MkfktJ91kRh4=
|
||||||
golang.org/x/net v0.0.0-20190827160401-ba9fcec4b297/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/sys v0.0.0-20190830023255-19e00faab6ad h1:cCejgArrk10gX6kFqjWeLwXD7aVMqWoRpyUCaaJSggc=
|
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c h1:6Zx7DRlKXf79yfxuQ/7GqV3w2y7aDsk6bGg0MzF5RVU=
|
||||||
golang.org/x/sys v0.0.0-20190830023255-19e00faab6ad/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||||
|
|||||||
@@ -8,6 +8,8 @@ package ipc
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/ipc/winpipe"
|
"golang.zx2c4.com/wireguard/ipc/winpipe"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,8 +49,16 @@ func (l *UAPIListener) Addr() net.Addr {
|
|||||||
return l.listener.Addr()
|
return l.listener.Addr()
|
||||||
}
|
}
|
||||||
|
|
||||||
/* SDDL_DEVOBJ_SYS_ALL from the WDK */
|
var UAPISecurityDescriptor *windows.SECURITY_DESCRIPTOR
|
||||||
var UAPISecurityDescriptor = "O:SYD:P(A;;GA;;;SY)"
|
|
||||||
|
func init() {
|
||||||
|
var err error
|
||||||
|
/* SDDL_DEVOBJ_SYS_ALL from the WDK */
|
||||||
|
UAPISecurityDescriptor, err = windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func UAPIListen(name string) (net.Listener, error) {
|
func UAPIListen(name string) (net.Listener, error) {
|
||||||
config := winpipe.PipeConfig{
|
config := winpipe.PipeConfig{
|
||||||
|
|||||||
@@ -13,15 +13,16 @@ import (
|
|||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
//sys cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) = CancelIoEx
|
//sys cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) = CancelIoEx
|
||||||
//sys createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) = CreateIoCompletionPort
|
//sys createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) = CreateIoCompletionPort
|
||||||
//sys getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
|
//sys getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) = GetQueuedCompletionStatus
|
||||||
//sys setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
|
//sys setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) = SetFileCompletionNotificationModes
|
||||||
//sys wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
|
//sys wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) = ws2_32.WSAGetOverlappedResult
|
||||||
|
|
||||||
type atomicBool int32
|
type atomicBool int32
|
||||||
|
|
||||||
@@ -55,7 +56,7 @@ func (e *timeoutError) Temporary() bool { return true }
|
|||||||
type timeoutChan chan struct{}
|
type timeoutChan chan struct{}
|
||||||
|
|
||||||
var ioInitOnce sync.Once
|
var ioInitOnce sync.Once
|
||||||
var ioCompletionPort syscall.Handle
|
var ioCompletionPort windows.Handle
|
||||||
|
|
||||||
// ioResult contains the result of an asynchronous IO operation
|
// ioResult contains the result of an asynchronous IO operation
|
||||||
type ioResult struct {
|
type ioResult struct {
|
||||||
@@ -65,12 +66,12 @@ type ioResult struct {
|
|||||||
|
|
||||||
// ioOperation represents an outstanding asynchronous Win32 IO
|
// ioOperation represents an outstanding asynchronous Win32 IO
|
||||||
type ioOperation struct {
|
type ioOperation struct {
|
||||||
o syscall.Overlapped
|
o windows.Overlapped
|
||||||
ch chan ioResult
|
ch chan ioResult
|
||||||
}
|
}
|
||||||
|
|
||||||
func initIo() {
|
func initIo() {
|
||||||
h, err := createIoCompletionPort(syscall.InvalidHandle, 0, 0, 0xffffffff)
|
h, err := createIoCompletionPort(windows.InvalidHandle, 0, 0, 0xffffffff)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -81,7 +82,7 @@ func initIo() {
|
|||||||
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
|
// win32File implements Reader, Writer, and Closer on a Win32 handle without blocking in a syscall.
|
||||||
// It takes ownership of this handle and will close it if it is garbage collected.
|
// It takes ownership of this handle and will close it if it is garbage collected.
|
||||||
type win32File struct {
|
type win32File struct {
|
||||||
handle syscall.Handle
|
handle windows.Handle
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
wgLock sync.RWMutex
|
wgLock sync.RWMutex
|
||||||
closing atomicBool
|
closing atomicBool
|
||||||
@@ -99,7 +100,7 @@ type deadlineHandler struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// makeWin32File makes a new win32File from an existing file handle
|
// makeWin32File makes a new win32File from an existing file handle
|
||||||
func makeWin32File(h syscall.Handle) (*win32File, error) {
|
func makeWin32File(h windows.Handle) (*win32File, error) {
|
||||||
f := &win32File{handle: h}
|
f := &win32File{handle: h}
|
||||||
ioInitOnce.Do(initIo)
|
ioInitOnce.Do(initIo)
|
||||||
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
|
_, err := createIoCompletionPort(h, ioCompletionPort, 0, 0xffffffff)
|
||||||
@@ -115,7 +116,7 @@ func makeWin32File(h syscall.Handle) (*win32File, error) {
|
|||||||
return f, nil
|
return f, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func MakeOpenFile(h syscall.Handle) (io.ReadWriteCloser, error) {
|
func MakeOpenFile(h windows.Handle) (io.ReadWriteCloser, error) {
|
||||||
return makeWin32File(h)
|
return makeWin32File(h)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,7 +130,7 @@ func (f *win32File) closeHandle() {
|
|||||||
cancelIoEx(f.handle, nil)
|
cancelIoEx(f.handle, nil)
|
||||||
f.wg.Wait()
|
f.wg.Wait()
|
||||||
// at this point, no new IO can start
|
// at this point, no new IO can start
|
||||||
syscall.Close(f.handle)
|
windows.Close(f.handle)
|
||||||
f.handle = 0
|
f.handle = 0
|
||||||
} else {
|
} else {
|
||||||
f.wgLock.Unlock()
|
f.wgLock.Unlock()
|
||||||
@@ -158,12 +159,12 @@ func (f *win32File) prepareIo() (*ioOperation, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ioCompletionProcessor processes completed async IOs forever
|
// ioCompletionProcessor processes completed async IOs forever
|
||||||
func ioCompletionProcessor(h syscall.Handle) {
|
func ioCompletionProcessor(h windows.Handle) {
|
||||||
for {
|
for {
|
||||||
var bytes uint32
|
var bytes uint32
|
||||||
var key uintptr
|
var key uintptr
|
||||||
var op *ioOperation
|
var op *ioOperation
|
||||||
err := getQueuedCompletionStatus(h, &bytes, &key, &op, syscall.INFINITE)
|
err := getQueuedCompletionStatus(h, &bytes, &key, &op, windows.INFINITE)
|
||||||
if op == nil {
|
if op == nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
@@ -174,7 +175,7 @@ func ioCompletionProcessor(h syscall.Handle) {
|
|||||||
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
|
// asyncIo processes the return value from ReadFile or WriteFile, blocking until
|
||||||
// the operation has actually completed.
|
// the operation has actually completed.
|
||||||
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
|
func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, err error) (int, error) {
|
||||||
if err != syscall.ERROR_IO_PENDING {
|
if err != windows.ERROR_IO_PENDING {
|
||||||
return int(bytes), err
|
return int(bytes), err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -193,7 +194,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
|
|||||||
select {
|
select {
|
||||||
case r = <-c.ch:
|
case r = <-c.ch:
|
||||||
err = r.err
|
err = r.err
|
||||||
if err == syscall.ERROR_OPERATION_ABORTED {
|
if err == windows.ERROR_OPERATION_ABORTED {
|
||||||
if f.closing.isSet() {
|
if f.closing.isSet() {
|
||||||
err = ErrFileClosed
|
err = ErrFileClosed
|
||||||
}
|
}
|
||||||
@@ -206,7 +207,7 @@ func (f *win32File) asyncIo(c *ioOperation, d *deadlineHandler, bytes uint32, er
|
|||||||
cancelIoEx(f.handle, &c.o)
|
cancelIoEx(f.handle, &c.o)
|
||||||
r = <-c.ch
|
r = <-c.ch
|
||||||
err = r.err
|
err = r.err
|
||||||
if err == syscall.ERROR_OPERATION_ABORTED {
|
if err == windows.ERROR_OPERATION_ABORTED {
|
||||||
err = ErrTimeout
|
err = ErrTimeout
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -231,14 +232,14 @@ func (f *win32File) Read(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var bytes uint32
|
var bytes uint32
|
||||||
err = syscall.ReadFile(f.handle, b, &bytes, &c.o)
|
err = windows.ReadFile(f.handle, b, &bytes, &c.o)
|
||||||
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
|
n, err := f.asyncIo(c, &f.readDeadline, bytes, err)
|
||||||
runtime.KeepAlive(b)
|
runtime.KeepAlive(b)
|
||||||
|
|
||||||
// Handle EOF conditions.
|
// Handle EOF conditions.
|
||||||
if err == nil && n == 0 && len(b) != 0 {
|
if err == nil && n == 0 && len(b) != 0 {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
} else if err == syscall.ERROR_BROKEN_PIPE {
|
} else if err == windows.ERROR_BROKEN_PIPE {
|
||||||
return 0, io.EOF
|
return 0, io.EOF
|
||||||
} else {
|
} else {
|
||||||
return n, err
|
return n, err
|
||||||
@@ -258,7 +259,7 @@ func (f *win32File) Write(b []byte) (int, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var bytes uint32
|
var bytes uint32
|
||||||
err = syscall.WriteFile(f.handle, b, &bytes, &c.o)
|
err = windows.WriteFile(f.handle, b, &bytes, &c.o)
|
||||||
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
|
n, err := f.asyncIo(c, &f.writeDeadline, bytes, err)
|
||||||
runtime.KeepAlive(b)
|
runtime.KeepAlive(b)
|
||||||
return n, err
|
return n, err
|
||||||
@@ -273,7 +274,7 @@ func (f *win32File) SetWriteDeadline(deadline time.Time) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *win32File) Flush() error {
|
func (f *win32File) Flush() error {
|
||||||
return syscall.FlushFileBuffers(f.handle)
|
return windows.FlushFileBuffers(f.handle)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *win32File) Fd() uintptr {
|
func (f *win32File) Fd() uintptr {
|
||||||
|
|||||||
@@ -6,4 +6,4 @@
|
|||||||
|
|
||||||
package winpipe
|
package winpipe
|
||||||
|
|
||||||
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go pipe.go sd.go file.go
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go pipe.go file.go
|
||||||
|
|||||||
@@ -16,18 +16,19 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"syscall"
|
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
//sys connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) = ConnectNamedPipe
|
//sys connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) = ConnectNamedPipe
|
||||||
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateNamedPipeW
|
//sys createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateNamedPipeW
|
||||||
//sys createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) [failretval==syscall.InvalidHandle] = CreateFileW
|
//sys createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) [failretval==windows.InvalidHandle] = CreateFileW
|
||||||
//sys getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
|
//sys getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) = GetNamedPipeInfo
|
||||||
//sys getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
|
//sys getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
|
||||||
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
|
//sys localAlloc(uFlags uint32, length uint32) (ptr uintptr) = LocalAlloc
|
||||||
//sys ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
|
//sys ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) = ntdll.NtCreateNamedPipeFile
|
||||||
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
|
//sys rtlNtStatusToDosError(status ntstatus) (winerr error) = ntdll.RtlNtStatusToDosErrorNoTeb
|
||||||
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
|
//sys rtlDosPathNameToNtPathName(name *uint16, ntName *unicodeString, filePart uintptr, reserved uintptr) (status ntstatus) = ntdll.RtlDosPathNameToNtPathName_U
|
||||||
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
|
//sys rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) = ntdll.RtlDefaultNpAcl
|
||||||
@@ -41,7 +42,7 @@ type objectAttributes struct {
|
|||||||
RootDirectory uintptr
|
RootDirectory uintptr
|
||||||
ObjectName *unicodeString
|
ObjectName *unicodeString
|
||||||
Attributes uintptr
|
Attributes uintptr
|
||||||
SecurityDescriptor *securityDescriptor
|
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
|
||||||
SecurityQoS uintptr
|
SecurityQoS uintptr
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -51,16 +52,6 @@ type unicodeString struct {
|
|||||||
Buffer uintptr
|
Buffer uintptr
|
||||||
}
|
}
|
||||||
|
|
||||||
type securityDescriptor struct {
|
|
||||||
Revision byte
|
|
||||||
Sbz1 byte
|
|
||||||
Control uint16
|
|
||||||
Owner uintptr
|
|
||||||
Group uintptr
|
|
||||||
Sacl uintptr
|
|
||||||
Dacl uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
type ntstatus int32
|
type ntstatus int32
|
||||||
|
|
||||||
func (status ntstatus) Err() error {
|
func (status ntstatus) Err() error {
|
||||||
@@ -71,11 +62,6 @@ func (status ntstatus) Err() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
cERROR_PIPE_BUSY = syscall.Errno(231)
|
|
||||||
cERROR_NO_DATA = syscall.Errno(232)
|
|
||||||
cERROR_PIPE_CONNECTED = syscall.Errno(535)
|
|
||||||
cERROR_SEM_TIMEOUT = syscall.Errno(121)
|
|
||||||
|
|
||||||
cSECURITY_SQOS_PRESENT = 0x100000
|
cSECURITY_SQOS_PRESENT = 0x100000
|
||||||
cSECURITY_ANONYMOUS = 0
|
cSECURITY_ANONYMOUS = 0
|
||||||
|
|
||||||
@@ -88,8 +74,6 @@ const (
|
|||||||
|
|
||||||
cFILE_PIPE_MESSAGE_TYPE = 1
|
cFILE_PIPE_MESSAGE_TYPE = 1
|
||||||
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
|
cFILE_PIPE_REJECT_REMOTE_CLIENTS = 2
|
||||||
|
|
||||||
cSE_DACL_PRESENT = 4
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -170,7 +154,7 @@ func (f *win32MessageBytePipe) Read(b []byte) (int, error) {
|
|||||||
// zero-byte message, ensure that all future Read() calls
|
// zero-byte message, ensure that all future Read() calls
|
||||||
// also return EOF.
|
// also return EOF.
|
||||||
f.readEOF = true
|
f.readEOF = true
|
||||||
} else if err == syscall.ERROR_MORE_DATA {
|
} else if err == windows.ERROR_MORE_DATA {
|
||||||
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
|
// ERROR_MORE_DATA indicates that the pipe's read mode is message mode
|
||||||
// and the message still has more bytes. Treat this as a success, since
|
// and the message still has more bytes. Treat this as a success, since
|
||||||
// this package presents all named pipes as byte streams.
|
// this package presents all named pipes as byte streams.
|
||||||
@@ -188,17 +172,17 @@ func (s pipeAddress) String() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
|
// tryDialPipe attempts to dial the pipe at `path` until `ctx` cancellation or timeout.
|
||||||
func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
|
func tryDialPipe(ctx context.Context, path *string) (windows.Handle, error) {
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return syscall.Handle(0), ctx.Err()
|
return windows.Handle(0), ctx.Err()
|
||||||
default:
|
default:
|
||||||
h, err := createFile(*path, syscall.GENERIC_READ|syscall.GENERIC_WRITE, 0, nil, syscall.OPEN_EXISTING, syscall.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
|
h, err := createFile(*path, windows.GENERIC_READ|windows.GENERIC_WRITE, 0, nil, windows.OPEN_EXISTING, windows.FILE_FLAG_OVERLAPPED|cSECURITY_SQOS_PRESENT|cSECURITY_ANONYMOUS, 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return h, nil
|
return h, nil
|
||||||
}
|
}
|
||||||
if err != cERROR_PIPE_BUSY {
|
if err != windows.ERROR_PIPE_BUSY {
|
||||||
return h, &os.PathError{Err: err, Op: "open", Path: *path}
|
return h, &os.PathError{Err: err, Op: "open", Path: *path}
|
||||||
}
|
}
|
||||||
// Wait 10 msec and try again. This is a rather simplistic
|
// Wait 10 msec and try again. This is a rather simplistic
|
||||||
@@ -211,7 +195,7 @@ func tryDialPipe(ctx context.Context, path *string) (syscall.Handle, error) {
|
|||||||
// DialPipe connects to a named pipe by path, timing out if the connection
|
// DialPipe connects to a named pipe by path, timing out if the connection
|
||||||
// takes longer than the specified duration. If timeout is nil, then we use
|
// takes longer than the specified duration. If timeout is nil, then we use
|
||||||
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
|
// a default timeout of 2 seconds. (We do not use WaitNamedPipe.)
|
||||||
func DialPipe(path string, timeout *time.Duration, expectedOwner *syscall.SID) (net.Conn, error) {
|
func DialPipe(path string, timeout *time.Duration, expectedOwner *windows.SID) (net.Conn, error) {
|
||||||
var absTimeout time.Time
|
var absTimeout time.Time
|
||||||
if timeout != nil {
|
if timeout != nil {
|
||||||
absTimeout = time.Now().Add(*timeout)
|
absTimeout = time.Now().Add(*timeout)
|
||||||
@@ -228,39 +212,41 @@ func DialPipe(path string, timeout *time.Duration, expectedOwner *syscall.SID) (
|
|||||||
|
|
||||||
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
|
// DialPipeContext attempts to connect to a named pipe by `path` until `ctx`
|
||||||
// cancellation or timeout.
|
// cancellation or timeout.
|
||||||
func DialPipeContext(ctx context.Context, path string, expectedOwner *syscall.SID) (net.Conn, error) {
|
func DialPipeContext(ctx context.Context, path string, expectedOwner *windows.SID) (net.Conn, error) {
|
||||||
var err error
|
var err error
|
||||||
var h syscall.Handle
|
var h windows.Handle
|
||||||
h, err = tryDialPipe(ctx, &path)
|
h, err = tryDialPipe(ctx, &path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if expectedOwner != nil {
|
if expectedOwner != nil {
|
||||||
var realOwner *syscall.SID
|
sd, err := windows.GetSecurityInfo(h, windows.SE_FILE_OBJECT, windows.OWNER_SECURITY_INFORMATION)
|
||||||
var realSd uintptr
|
|
||||||
err = getSecurityInfo(h, SE_FILE_OBJECT, OWNER_SECURITY_INFORMATION, &realOwner, nil, nil, nil, &realSd)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
syscall.Close(h)
|
windows.Close(h)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
defer localFree(realSd)
|
realOwner, _, err := sd.Owner()
|
||||||
if !equalSid(realOwner, expectedOwner) {
|
if err != nil {
|
||||||
syscall.Close(h)
|
windows.Close(h)
|
||||||
return nil, syscall.ERROR_ACCESS_DENIED
|
return nil, err
|
||||||
|
}
|
||||||
|
if !realOwner.Equals(expectedOwner) {
|
||||||
|
windows.Close(h)
|
||||||
|
return nil, windows.ERROR_ACCESS_DENIED
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var flags uint32
|
var flags uint32
|
||||||
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
|
err = getNamedPipeInfo(h, &flags, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
syscall.Close(h)
|
windows.Close(h)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
f, err := makeWin32File(h)
|
f, err := makeWin32File(h)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
syscall.Close(h)
|
windows.Close(h)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -280,7 +266,7 @@ type acceptResponse struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type win32PipeListener struct {
|
type win32PipeListener struct {
|
||||||
firstHandle syscall.Handle
|
firstHandle windows.Handle
|
||||||
path string
|
path string
|
||||||
config PipeConfig
|
config PipeConfig
|
||||||
acceptCh chan (chan acceptResponse)
|
acceptCh chan (chan acceptResponse)
|
||||||
@@ -288,8 +274,8 @@ type win32PipeListener struct {
|
|||||||
doneCh chan int
|
doneCh chan int
|
||||||
}
|
}
|
||||||
|
|
||||||
func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (syscall.Handle, error) {
|
func makeServerPipeHandle(path string, sd *windows.SECURITY_DESCRIPTOR, c *PipeConfig, first bool) (windows.Handle, error) {
|
||||||
path16, err := syscall.UTF16FromString(path)
|
path16, err := windows.UTF16FromString(path)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
}
|
}
|
||||||
@@ -301,31 +287,32 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
|
|||||||
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
|
if err := rtlDosPathNameToNtPathName(&path16[0], &ntPath, 0, 0).Err(); err != nil {
|
||||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
}
|
}
|
||||||
defer localFree(ntPath.Buffer)
|
defer windows.LocalFree(windows.Handle(ntPath.Buffer))
|
||||||
oa.ObjectName = &ntPath
|
oa.ObjectName = &ntPath
|
||||||
|
|
||||||
// The security descriptor is only needed for the first pipe.
|
// The security descriptor is only needed for the first pipe.
|
||||||
if first {
|
if first {
|
||||||
if sd != nil {
|
if sd != nil {
|
||||||
len := uint32(len(sd))
|
oa.SecurityDescriptor = sd
|
||||||
sdb := localAlloc(0, len)
|
|
||||||
defer localFree(sdb)
|
|
||||||
copy((*[0xffff]byte)(unsafe.Pointer(sdb))[:], sd)
|
|
||||||
oa.SecurityDescriptor = (*securityDescriptor)(unsafe.Pointer(sdb))
|
|
||||||
} else {
|
} else {
|
||||||
// Construct the default named pipe security descriptor.
|
// Construct the default named pipe security descriptor.
|
||||||
var dacl uintptr
|
var dacl uintptr
|
||||||
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
|
if err := rtlDefaultNpAcl(&dacl).Err(); err != nil {
|
||||||
return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
|
return 0, fmt.Errorf("getting default named pipe ACL: %s", err)
|
||||||
}
|
}
|
||||||
defer localFree(dacl)
|
defer windows.LocalFree(windows.Handle(dacl))
|
||||||
|
sd, err := windows.NewSecurityDescriptor()
|
||||||
sdb := &securityDescriptor{
|
if err != nil {
|
||||||
Revision: 1,
|
return 0, fmt.Errorf("creating new security descriptor: %s", err)
|
||||||
Control: cSE_DACL_PRESENT,
|
|
||||||
Dacl: dacl,
|
|
||||||
}
|
}
|
||||||
oa.SecurityDescriptor = sdb
|
if err = sd.SetDACL((*windows.ACL)(unsafe.Pointer(dacl)), true, false); err != nil {
|
||||||
|
return 0, fmt.Errorf("assigning dacl: %s", err)
|
||||||
|
}
|
||||||
|
sd, err = sd.ToSelfRelative()
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("converting to self-relative: %s", err)
|
||||||
|
}
|
||||||
|
oa.SecurityDescriptor = sd
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -335,22 +322,22 @@ func makeServerPipeHandle(path string, sd []byte, c *PipeConfig, first bool) (sy
|
|||||||
}
|
}
|
||||||
|
|
||||||
disposition := uint32(cFILE_OPEN)
|
disposition := uint32(cFILE_OPEN)
|
||||||
access := uint32(syscall.GENERIC_READ | syscall.GENERIC_WRITE | syscall.SYNCHRONIZE)
|
access := uint32(windows.GENERIC_READ | windows.GENERIC_WRITE | windows.SYNCHRONIZE)
|
||||||
if first {
|
if first {
|
||||||
disposition = cFILE_CREATE
|
disposition = cFILE_CREATE
|
||||||
// By not asking for read or write access, the named pipe file system
|
// By not asking for read or write access, the named pipe file system
|
||||||
// will put this pipe into an initially disconnected state, blocking
|
// will put this pipe into an initially disconnected state, blocking
|
||||||
// client connections until the next call with first == false.
|
// client connections until the next call with first == false.
|
||||||
access = syscall.SYNCHRONIZE
|
access = windows.SYNCHRONIZE
|
||||||
}
|
}
|
||||||
|
|
||||||
timeout := int64(-50 * 10000) // 50ms
|
timeout := int64(-50 * 10000) // 50ms
|
||||||
|
|
||||||
var (
|
var (
|
||||||
h syscall.Handle
|
h windows.Handle
|
||||||
iosb ioStatusBlock
|
iosb ioStatusBlock
|
||||||
)
|
)
|
||||||
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, syscall.FILE_SHARE_READ|syscall.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
|
err = ntCreateNamedPipeFile(&h, access, &oa, &iosb, windows.FILE_SHARE_READ|windows.FILE_SHARE_WRITE, disposition, 0, typ, 0, 0, 0xffffffff, uint32(c.InputBufferSize), uint32(c.OutputBufferSize), &timeout).Err()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
return 0, &os.PathError{Op: "open", Path: path, Err: err}
|
||||||
}
|
}
|
||||||
@@ -366,7 +353,7 @@ func (l *win32PipeListener) makeServerPipe() (*win32File, error) {
|
|||||||
}
|
}
|
||||||
f, err := makeWin32File(h)
|
f, err := makeWin32File(h)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
syscall.Close(h)
|
windows.Close(h)
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return f, nil
|
return f, nil
|
||||||
@@ -417,7 +404,7 @@ func (l *win32PipeListener) listenerRoutine() {
|
|||||||
p, err = l.makeConnectedServerPipe()
|
p, err = l.makeConnectedServerPipe()
|
||||||
// If the connection was immediately closed by the client, try
|
// If the connection was immediately closed by the client, try
|
||||||
// again.
|
// again.
|
||||||
if err != cERROR_NO_DATA {
|
if err != windows.ERROR_NO_DATA {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -425,7 +412,7 @@ func (l *win32PipeListener) listenerRoutine() {
|
|||||||
closed = err == ErrPipeListenerClosed
|
closed = err == ErrPipeListenerClosed
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
syscall.Close(l.firstHandle)
|
windows.Close(l.firstHandle)
|
||||||
l.firstHandle = 0
|
l.firstHandle = 0
|
||||||
// Notify Close() and Accept() callers that the handle has been closed.
|
// Notify Close() and Accept() callers that the handle has been closed.
|
||||||
close(l.doneCh)
|
close(l.doneCh)
|
||||||
@@ -433,8 +420,8 @@ func (l *win32PipeListener) listenerRoutine() {
|
|||||||
|
|
||||||
// PipeConfig contain configuration for the pipe listener.
|
// PipeConfig contain configuration for the pipe listener.
|
||||||
type PipeConfig struct {
|
type PipeConfig struct {
|
||||||
// SecurityDescriptor contains a Windows security descriptor in SDDL format.
|
// SecurityDescriptor contains a Windows security descriptor.
|
||||||
SecurityDescriptor string
|
SecurityDescriptor *windows.SECURITY_DESCRIPTOR
|
||||||
|
|
||||||
// MessageMode determines whether the pipe is in byte or message mode. In either
|
// MessageMode determines whether the pipe is in byte or message mode. In either
|
||||||
// case the pipe is read in byte mode by default. The only practical difference in
|
// case the pipe is read in byte mode by default. The only practical difference in
|
||||||
@@ -454,20 +441,10 @@ type PipeConfig struct {
|
|||||||
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
|
// ListenPipe creates a listener on a Windows named pipe path, e.g. \\.\pipe\mypipe.
|
||||||
// The pipe must not already exist.
|
// The pipe must not already exist.
|
||||||
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
|
func ListenPipe(path string, c *PipeConfig) (net.Listener, error) {
|
||||||
var (
|
|
||||||
sd []byte
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
if c == nil {
|
if c == nil {
|
||||||
c = &PipeConfig{}
|
c = &PipeConfig{}
|
||||||
}
|
}
|
||||||
if c.SecurityDescriptor != "" {
|
h, err := makeServerPipeHandle(path, c.SecurityDescriptor, c, true)
|
||||||
sd, err = SddlToSecurityDescriptor(c.SecurityDescriptor)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
h, err := makeServerPipeHandle(path, sd, c, true)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -492,7 +469,7 @@ func connectPipe(p *win32File) error {
|
|||||||
|
|
||||||
err = connectNamedPipe(p.handle, &c.o)
|
err = connectNamedPipe(p.handle, &c.o)
|
||||||
_, err = p.asyncIo(c, nil, 0, err)
|
_, err = p.asyncIo(c, nil, 0, err)
|
||||||
if err != nil && err != cERROR_PIPE_CONNECTED {
|
if err != nil && err != windows.ERROR_PIPE_CONNECTED {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -1,36 +0,0 @@
|
|||||||
// +build windows
|
|
||||||
|
|
||||||
/* SPDX-License-Identifier: MIT
|
|
||||||
*
|
|
||||||
* Copyright (C) 2005 Microsoft
|
|
||||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package winpipe
|
|
||||||
|
|
||||||
import (
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
//sys convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) = advapi32.ConvertStringSecurityDescriptorToSecurityDescriptorW
|
|
||||||
//sys localFree(mem uintptr) = LocalFree
|
|
||||||
//sys getSecurityDescriptorLength(sd uintptr) (len uint32) = advapi32.GetSecurityDescriptorLength
|
|
||||||
//sys getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) = advapi32.GetSecurityInfo
|
|
||||||
//sys equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) = advapi32.EqualSid
|
|
||||||
|
|
||||||
const (
|
|
||||||
SE_FILE_OBJECT = 1
|
|
||||||
OWNER_SECURITY_INFORMATION = 1
|
|
||||||
)
|
|
||||||
|
|
||||||
func SddlToSecurityDescriptor(sddl string) ([]byte, error) {
|
|
||||||
var sdBuffer uintptr
|
|
||||||
err := convertStringSecurityDescriptorToSecurityDescriptor(sddl, 1, &sdBuffer, nil)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer localFree(sdBuffer)
|
|
||||||
sd := make([]byte, getSecurityDescriptorLength(sdBuffer))
|
|
||||||
copy(sd, (*[0xffff]byte)(unsafe.Pointer(sdBuffer))[:len(sd)])
|
|
||||||
return sd, nil
|
|
||||||
}
|
|
||||||
@@ -39,7 +39,6 @@ func errnoErr(e syscall.Errno) error {
|
|||||||
var (
|
var (
|
||||||
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||||
modntdll = windows.NewLazySystemDLL("ntdll.dll")
|
modntdll = windows.NewLazySystemDLL("ntdll.dll")
|
||||||
modadvapi32 = windows.NewLazySystemDLL("advapi32.dll")
|
|
||||||
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
modws2_32 = windows.NewLazySystemDLL("ws2_32.dll")
|
||||||
|
|
||||||
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
|
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
|
||||||
@@ -52,11 +51,6 @@ var (
|
|||||||
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
|
procRtlNtStatusToDosErrorNoTeb = modntdll.NewProc("RtlNtStatusToDosErrorNoTeb")
|
||||||
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
|
procRtlDosPathNameToNtPathName_U = modntdll.NewProc("RtlDosPathNameToNtPathName_U")
|
||||||
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
|
procRtlDefaultNpAcl = modntdll.NewProc("RtlDefaultNpAcl")
|
||||||
procConvertStringSecurityDescriptorToSecurityDescriptorW = modadvapi32.NewProc("ConvertStringSecurityDescriptorToSecurityDescriptorW")
|
|
||||||
procLocalFree = modkernel32.NewProc("LocalFree")
|
|
||||||
procGetSecurityDescriptorLength = modadvapi32.NewProc("GetSecurityDescriptorLength")
|
|
||||||
procGetSecurityInfo = modadvapi32.NewProc("GetSecurityInfo")
|
|
||||||
procEqualSid = modadvapi32.NewProc("EqualSid")
|
|
||||||
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
|
procCancelIoEx = modkernel32.NewProc("CancelIoEx")
|
||||||
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
|
procCreateIoCompletionPort = modkernel32.NewProc("CreateIoCompletionPort")
|
||||||
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
|
procGetQueuedCompletionStatus = modkernel32.NewProc("GetQueuedCompletionStatus")
|
||||||
@@ -64,7 +58,7 @@ var (
|
|||||||
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
|
procWSAGetOverlappedResult = modws2_32.NewProc("WSAGetOverlappedResult")
|
||||||
)
|
)
|
||||||
|
|
||||||
func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
|
func connectNamedPipe(pipe windows.Handle, o *windows.Overlapped) (err error) {
|
||||||
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
|
r1, _, e1 := syscall.Syscall(procConnectNamedPipe.Addr(), 2, uintptr(pipe), uintptr(unsafe.Pointer(o)), 0)
|
||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
@@ -76,7 +70,7 @@ func connectNamedPipe(pipe syscall.Handle, o *syscall.Overlapped) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
|
func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
|
||||||
var _p0 *uint16
|
var _p0 *uint16
|
||||||
_p0, err = syscall.UTF16PtrFromString(name)
|
_p0, err = syscall.UTF16PtrFromString(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -85,10 +79,10 @@ func createNamedPipe(name string, flags uint32, pipeMode uint32, maxInstances ui
|
|||||||
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
|
return _createNamedPipe(_p0, flags, pipeMode, maxInstances, outSize, inSize, defaultTimeout, sa)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *syscall.SecurityAttributes) (handle syscall.Handle, err error) {
|
func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *windows.SecurityAttributes) (handle windows.Handle, err error) {
|
||||||
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
|
r0, _, e1 := syscall.Syscall9(procCreateNamedPipeW.Addr(), 8, uintptr(unsafe.Pointer(name)), uintptr(flags), uintptr(pipeMode), uintptr(maxInstances), uintptr(outSize), uintptr(inSize), uintptr(defaultTimeout), uintptr(unsafe.Pointer(sa)), 0)
|
||||||
handle = syscall.Handle(r0)
|
handle = windows.Handle(r0)
|
||||||
if handle == syscall.InvalidHandle {
|
if handle == windows.InvalidHandle {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
err = errnoErr(e1)
|
err = errnoErr(e1)
|
||||||
} else {
|
} else {
|
||||||
@@ -98,7 +92,7 @@ func _createNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
|
func createFile(name string, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
|
||||||
var _p0 *uint16
|
var _p0 *uint16
|
||||||
_p0, err = syscall.UTF16PtrFromString(name)
|
_p0, err = syscall.UTF16PtrFromString(name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -107,10 +101,10 @@ func createFile(name string, access uint32, mode uint32, sa *syscall.SecurityAtt
|
|||||||
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
|
return _createFile(_p0, access, mode, sa, createmode, attrs, templatefile)
|
||||||
}
|
}
|
||||||
|
|
||||||
func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityAttributes, createmode uint32, attrs uint32, templatefile syscall.Handle) (handle syscall.Handle, err error) {
|
func _createFile(name *uint16, access uint32, mode uint32, sa *windows.SecurityAttributes, createmode uint32, attrs uint32, templatefile windows.Handle) (handle windows.Handle, err error) {
|
||||||
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
|
r0, _, e1 := syscall.Syscall9(procCreateFileW.Addr(), 7, uintptr(unsafe.Pointer(name)), uintptr(access), uintptr(mode), uintptr(unsafe.Pointer(sa)), uintptr(createmode), uintptr(attrs), uintptr(templatefile), 0, 0)
|
||||||
handle = syscall.Handle(r0)
|
handle = windows.Handle(r0)
|
||||||
if handle == syscall.InvalidHandle {
|
if handle == windows.InvalidHandle {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
err = errnoErr(e1)
|
err = errnoErr(e1)
|
||||||
} else {
|
} else {
|
||||||
@@ -120,7 +114,7 @@ func _createFile(name *uint16, access uint32, mode uint32, sa *syscall.SecurityA
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
|
func getNamedPipeInfo(pipe windows.Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) {
|
||||||
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
|
r1, _, e1 := syscall.Syscall6(procGetNamedPipeInfo.Addr(), 5, uintptr(pipe), uintptr(unsafe.Pointer(flags)), uintptr(unsafe.Pointer(outSize)), uintptr(unsafe.Pointer(inSize)), uintptr(unsafe.Pointer(maxInstances)), 0)
|
||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
@@ -132,7 +126,7 @@ func getNamedPipeInfo(pipe syscall.Handle, flags *uint32, outSize *uint32, inSiz
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func getNamedPipeHandleState(pipe syscall.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
|
func getNamedPipeHandleState(pipe windows.Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) {
|
||||||
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
|
r1, _, e1 := syscall.Syscall9(procGetNamedPipeHandleStateW.Addr(), 7, uintptr(pipe), uintptr(unsafe.Pointer(state)), uintptr(unsafe.Pointer(curInstances)), uintptr(unsafe.Pointer(maxCollectionCount)), uintptr(unsafe.Pointer(collectDataTimeout)), uintptr(unsafe.Pointer(userName)), uintptr(maxUserNameSize), 0, 0)
|
||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
@@ -150,7 +144,7 @@ func localAlloc(uFlags uint32, length uint32) (ptr uintptr) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func ntCreateNamedPipeFile(pipe *syscall.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
|
func ntCreateNamedPipeFile(pipe *windows.Handle, access uint32, oa *objectAttributes, iosb *ioStatusBlock, share uint32, disposition uint32, options uint32, typ uint32, readMode uint32, completionMode uint32, maxInstances uint32, inboundQuota uint32, outputQuota uint32, timeout *int64) (status ntstatus) {
|
||||||
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
|
r0, _, _ := syscall.Syscall15(procNtCreateNamedPipeFile.Addr(), 14, uintptr(unsafe.Pointer(pipe)), uintptr(access), uintptr(unsafe.Pointer(oa)), uintptr(unsafe.Pointer(iosb)), uintptr(share), uintptr(disposition), uintptr(options), uintptr(typ), uintptr(readMode), uintptr(completionMode), uintptr(maxInstances), uintptr(inboundQuota), uintptr(outputQuota), uintptr(unsafe.Pointer(timeout)), 0)
|
||||||
status = ntstatus(r0)
|
status = ntstatus(r0)
|
||||||
return
|
return
|
||||||
@@ -176,53 +170,7 @@ func rtlDefaultNpAcl(dacl *uintptr) (status ntstatus) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertStringSecurityDescriptorToSecurityDescriptor(str string, revision uint32, sd *uintptr, size *uint32) (err error) {
|
func cancelIoEx(file windows.Handle, o *windows.Overlapped) (err error) {
|
||||||
var _p0 *uint16
|
|
||||||
_p0, err = syscall.UTF16PtrFromString(str)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
return _convertStringSecurityDescriptorToSecurityDescriptor(_p0, revision, sd, size)
|
|
||||||
}
|
|
||||||
|
|
||||||
func _convertStringSecurityDescriptorToSecurityDescriptor(str *uint16, revision uint32, sd *uintptr, size *uint32) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall6(procConvertStringSecurityDescriptorToSecurityDescriptorW.Addr(), 4, uintptr(unsafe.Pointer(str)), uintptr(revision), uintptr(unsafe.Pointer(sd)), uintptr(unsafe.Pointer(size)), 0, 0)
|
|
||||||
if r1 == 0 {
|
|
||||||
if e1 != 0 {
|
|
||||||
err = errnoErr(e1)
|
|
||||||
} else {
|
|
||||||
err = syscall.EINVAL
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func localFree(mem uintptr) {
|
|
||||||
syscall.Syscall(procLocalFree.Addr(), 1, uintptr(mem), 0, 0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSecurityDescriptorLength(sd uintptr) (len uint32) {
|
|
||||||
r0, _, _ := syscall.Syscall(procGetSecurityDescriptorLength.Addr(), 1, uintptr(sd), 0, 0)
|
|
||||||
len = uint32(r0)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSecurityInfo(handle syscall.Handle, objectType uint32, securityInformation uint32, owner **syscall.SID, group **syscall.SID, dacl *uintptr, sacl *uintptr, sd *uintptr) (ret error) {
|
|
||||||
r0, _, _ := syscall.Syscall9(procGetSecurityInfo.Addr(), 8, uintptr(handle), uintptr(objectType), uintptr(securityInformation), uintptr(unsafe.Pointer(owner)), uintptr(unsafe.Pointer(group)), uintptr(unsafe.Pointer(dacl)), uintptr(unsafe.Pointer(sacl)), uintptr(unsafe.Pointer(sd)), 0)
|
|
||||||
if r0 != 0 {
|
|
||||||
ret = syscall.Errno(r0)
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func equalSid(sid1 *syscall.SID, sid2 *syscall.SID) (isEqual bool) {
|
|
||||||
r0, _, _ := syscall.Syscall(procEqualSid.Addr(), 2, uintptr(unsafe.Pointer(sid1)), uintptr(unsafe.Pointer(sid2)), 0)
|
|
||||||
isEqual = r0 != 0
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
|
|
||||||
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
|
r1, _, e1 := syscall.Syscall(procCancelIoEx.Addr(), 2, uintptr(file), uintptr(unsafe.Pointer(o)), 0)
|
||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
@@ -234,9 +182,9 @@ func cancelIoEx(file syscall.Handle, o *syscall.Overlapped) (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintptr, threadCount uint32) (newport syscall.Handle, err error) {
|
func createIoCompletionPort(file windows.Handle, port windows.Handle, key uintptr, threadCount uint32) (newport windows.Handle, err error) {
|
||||||
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
|
r0, _, e1 := syscall.Syscall6(procCreateIoCompletionPort.Addr(), 4, uintptr(file), uintptr(port), uintptr(key), uintptr(threadCount), 0, 0)
|
||||||
newport = syscall.Handle(r0)
|
newport = windows.Handle(r0)
|
||||||
if newport == 0 {
|
if newport == 0 {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
err = errnoErr(e1)
|
err = errnoErr(e1)
|
||||||
@@ -247,7 +195,7 @@ func createIoCompletionPort(file syscall.Handle, port syscall.Handle, key uintpt
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
|
func getQueuedCompletionStatus(port windows.Handle, bytes *uint32, key *uintptr, o **ioOperation, timeout uint32) (err error) {
|
||||||
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
|
r1, _, e1 := syscall.Syscall6(procGetQueuedCompletionStatus.Addr(), 5, uintptr(port), uintptr(unsafe.Pointer(bytes)), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(o)), uintptr(timeout), 0)
|
||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
@@ -259,7 +207,7 @@ func getQueuedCompletionStatus(port syscall.Handle, bytes *uint32, key *uintptr,
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err error) {
|
func setFileCompletionNotificationModes(h windows.Handle, flags uint8) (err error) {
|
||||||
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
|
r1, _, e1 := syscall.Syscall(procSetFileCompletionNotificationModes.Addr(), 2, uintptr(h), uintptr(flags), 0)
|
||||||
if r1 == 0 {
|
if r1 == 0 {
|
||||||
if e1 != 0 {
|
if e1 != 0 {
|
||||||
@@ -271,7 +219,7 @@ func setFileCompletionNotificationModes(h syscall.Handle, flags uint8) (err erro
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func wsaGetOverlappedResult(h syscall.Handle, o *syscall.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
|
func wsaGetOverlappedResult(h windows.Handle, o *windows.Overlapped, bytes *uint32, wait bool, flags *uint32) (err error) {
|
||||||
var _p0 uint32
|
var _p0 uint32
|
||||||
if wait {
|
if wait {
|
||||||
_p0 = 1
|
_p0 = 1
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ func main() {
|
|||||||
logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
|
logger.Info.Println("Starting wireguard-go version", device.WireGuardGoVersion)
|
||||||
logger.Debug.Println("Debug log enabled")
|
logger.Debug.Println("Debug log enabled")
|
||||||
|
|
||||||
tun, err := tun.CreateTUN(interfaceName)
|
tun, err := tun.CreateTUN(interfaceName, 0)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
realInterfaceName, err2 := tun.Name()
|
realInterfaceName, err2 := tun.Name()
|
||||||
if err2 == nil {
|
if err2 == nil {
|
||||||
|
|||||||
@@ -36,7 +36,7 @@ func TestRatelimiter(t *testing.T) {
|
|||||||
for i := 0; i < packetsBurstable; i++ {
|
for i := 0; i < packetsBurstable; i++ {
|
||||||
Add(RatelimiterResult{
|
Add(RatelimiterResult{
|
||||||
allowed: true,
|
allowed: true,
|
||||||
text: "inital burst",
|
text: "initial burst",
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -60,7 +60,13 @@ func (rw *RWCancel) ReadyRead() bool {
|
|||||||
fdset := fdSet{}
|
fdset := fdSet{}
|
||||||
fdset.set(rw.fd)
|
fdset.set(rw.fd)
|
||||||
fdset.set(closeFd)
|
fdset.set(closeFd)
|
||||||
err := unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
|
var err error
|
||||||
|
for {
|
||||||
|
err = unixSelect(max(rw.fd, closeFd)+1, &fdset.FdSet, nil, nil, nil)
|
||||||
|
if err == nil || !RetryAfterError(err) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -75,7 +81,13 @@ func (rw *RWCancel) ReadyWrite() bool {
|
|||||||
fdset := fdSet{}
|
fdset := fdSet{}
|
||||||
fdset.set(rw.fd)
|
fdset.set(rw.fd)
|
||||||
fdset.set(closeFd)
|
fdset.set(closeFd)
|
||||||
err := unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
|
var err error
|
||||||
|
for {
|
||||||
|
err = unixSelect(max(rw.fd, closeFd)+1, nil, &fdset.FdSet, nil, nil)
|
||||||
|
if err == nil || !RetryAfterError(err) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
@@ -42,6 +43,22 @@ type NativeTun struct {
|
|||||||
|
|
||||||
var sockaddrCtlSize uintptr = 32
|
var sockaddrCtlSize uintptr = 32
|
||||||
|
|
||||||
|
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
|
||||||
|
for i := 0; i < 20; i++ {
|
||||||
|
iface, err = net.InterfaceByIndex(index)
|
||||||
|
if err != nil {
|
||||||
|
if opErr, ok := err.(*net.OpError); ok {
|
||||||
|
if syscallErr, ok := opErr.Err.(*os.SyscallError); ok && syscallErr.Err == syscall.ENOMEM {
|
||||||
|
time.Sleep(time.Duration(i) * time.Second / 3)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return iface, err
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
||||||
var (
|
var (
|
||||||
statusUp bool
|
statusUp bool
|
||||||
@@ -74,7 +91,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
iface, err := net.InterfaceByIndex(ifindex)
|
iface, err := retryInterfaceByIndex(ifindex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tun.errors <- err
|
tun.errors <- err
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ type NativeTun struct {
|
|||||||
name string // name of interface
|
name string // name of interface
|
||||||
errors chan error // async error handling
|
errors chan error // async error handling
|
||||||
events chan Event // device related events
|
events chan Event // device related events
|
||||||
nopi bool // the device was pased IFF_NO_PI
|
nopi bool // the device was passed IFF_NO_PI
|
||||||
netlinkSock int
|
netlinkSock int
|
||||||
netlinkCancel *rwcancel.RWCancel
|
netlinkCancel *rwcancel.RWCancel
|
||||||
hackListenerClosed sync.Mutex
|
hackListenerClosed sync.Mutex
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
@@ -35,11 +36,12 @@ type NativeTun struct {
|
|||||||
wt *wintun.Interface
|
wt *wintun.Interface
|
||||||
handle windows.Handle
|
handle windows.Handle
|
||||||
close bool
|
close bool
|
||||||
rings wintun.RingDescriptor
|
|
||||||
events chan Event
|
events chan Event
|
||||||
errors chan error
|
errors chan error
|
||||||
forcedMTU int
|
forcedMTU int
|
||||||
rate rateJuggler
|
rate rateJuggler
|
||||||
|
rings *wintun.RingDescriptor
|
||||||
|
writeLock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
const WintunPool = wintun.Pool("WireGuard")
|
const WintunPool = wintun.Pool("WireGuard")
|
||||||
@@ -54,15 +56,15 @@ func nanotime() int64
|
|||||||
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
|
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
|
||||||
// interface with the same name exist, it is reused.
|
// interface with the same name exist, it is reused.
|
||||||
//
|
//
|
||||||
func CreateTUN(ifname string) (Device, error) {
|
func CreateTUN(ifname string, mtu int) (Device, error) {
|
||||||
return CreateTUNWithRequestedGUID(ifname, nil)
|
return CreateTUNWithRequestedGUID(ifname, nil, mtu)
|
||||||
}
|
}
|
||||||
|
|
||||||
//
|
//
|
||||||
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
|
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
|
||||||
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
|
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
|
||||||
//
|
//
|
||||||
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Device, error) {
|
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
|
||||||
var err error
|
var err error
|
||||||
var wt *wintun.Interface
|
var wt *wintun.Interface
|
||||||
|
|
||||||
@@ -72,12 +74,17 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev
|
|||||||
// If so, we delete it, in case it has weird residual configuration.
|
// If so, we delete it, in case it has weird residual configuration.
|
||||||
_, err = wt.DeleteInterface()
|
_, err = wt.DeleteInterface()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Unable to delete already existing Wintun interface: %v", err)
|
return nil, fmt.Errorf("Error deleting already existing interface: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID)
|
wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Unable to create Wintun interface: %v", err)
|
return nil, fmt.Errorf("Error creating interface: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
forcedMTU := 1420
|
||||||
|
if mtu > 0 {
|
||||||
|
forcedMTU = mtu
|
||||||
}
|
}
|
||||||
|
|
||||||
tun := &NativeTun{
|
tun := &NativeTun{
|
||||||
@@ -85,16 +92,16 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID) (Dev
|
|||||||
handle: windows.InvalidHandle,
|
handle: windows.InvalidHandle,
|
||||||
events: make(chan Event, 10),
|
events: make(chan Event, 10),
|
||||||
errors: make(chan error, 1),
|
errors: make(chan error, 1),
|
||||||
forcedMTU: 1500,
|
forcedMTU: forcedMTU,
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tun.rings.Init()
|
tun.rings, err = wintun.NewRingDescriptor()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tun.Close()
|
tun.Close()
|
||||||
return nil, fmt.Errorf("Error creating events: %v", err)
|
return nil, fmt.Errorf("Error creating events: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
tun.handle, err = tun.wt.Register(&tun.rings)
|
tun.handle, err = tun.wt.Register(tun.rings)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tun.Close()
|
tun.Close()
|
||||||
return nil, fmt.Errorf("Error registering rings: %v", err)
|
return nil, fmt.Errorf("Error registering rings: %v", err)
|
||||||
@@ -214,6 +221,9 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
|||||||
tun.rate.update(uint64(packetSize))
|
tun.rate.update(uint64(packetSize))
|
||||||
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
|
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
|
||||||
|
|
||||||
|
tun.writeLock.Lock()
|
||||||
|
defer tun.writeLock.Unlock()
|
||||||
|
|
||||||
buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
|
buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
|
||||||
if buffHead >= wintun.PacketCapacity {
|
if buffHead >= wintun.PacketCapacity {
|
||||||
return 0, os.ErrClosed
|
return 0, os.ErrClosed
|
||||||
@@ -244,6 +254,11 @@ func (tun *NativeTun) LUID() uint64 {
|
|||||||
return tun.wt.LUID()
|
return tun.wt.LUID()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Version returns the version of the Wintun driver and NDIS system currently loaded.
|
||||||
|
func (tun *NativeTun) Version() (driverVersion string, ndisVersion string, err error) {
|
||||||
|
return tun.wt.Version()
|
||||||
|
}
|
||||||
|
|
||||||
func (rate *rateJuggler) update(packetLen uint64) {
|
func (rate *rateJuggler) update(packetLen uint64) {
|
||||||
now := nanotime()
|
now := nanotime()
|
||||||
total := atomic.AddUint64(&rate.nextByteCount, packetLen)
|
total := atomic.AddUint64(&rate.nextByteCount, packetLen)
|
||||||
|
|||||||
@@ -5,4 +5,4 @@
|
|||||||
|
|
||||||
package iphlpapi
|
package iphlpapi
|
||||||
|
|
||||||
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go conversion_windows.go
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go conversion_windows.go
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
"golang.org/x/text/unicode/norm"
|
"golang.org/x/text/unicode/norm"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/ipc/winpipe"
|
|
||||||
"golang.zx2c4.com/wireguard/tun/wintun/namespaceapi"
|
"golang.zx2c4.com/wireguard/tun/wintun/namespaceapi"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -32,13 +31,13 @@ func initializeNamespace() error {
|
|||||||
if hasInitializedNamespace {
|
if hasInitializedNamespace {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
sd, err := winpipe.SddlToSecurityDescriptor("O:SYD:P(A;;GA;;;SY)")
|
sd, err := windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("SddlToSecurityDescriptor failed: %v", err)
|
return fmt.Errorf("SddlToSecurityDescriptor failed: %v", err)
|
||||||
}
|
}
|
||||||
wintunObjectSecurityAttributes = &windows.SecurityAttributes{
|
wintunObjectSecurityAttributes = &windows.SecurityAttributes{
|
||||||
Length: uint32(len(sd)),
|
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
|
||||||
SecurityDescriptor: uintptr(unsafe.Pointer(&sd[0])),
|
SecurityDescriptor: sd,
|
||||||
}
|
}
|
||||||
sid, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
|
sid, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -5,4 +5,4 @@
|
|||||||
|
|
||||||
package namespaceapi
|
package namespaceapi
|
||||||
|
|
||||||
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go namespaceapi_windows.go
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go namespaceapi_windows.go
|
||||||
|
|||||||
@@ -40,7 +40,7 @@ func (bd *BoundaryDescriptor) AddSid(requiredSid *windows.SID) error {
|
|||||||
return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid)
|
return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PrivateNamespace represents a private namespace. Duh?!
|
// PrivateNamespace represents a private namespace.
|
||||||
type PrivateNamespace windows.Handle
|
type PrivateNamespace windows.Handle
|
||||||
|
|
||||||
// CreatePrivateNamespace creates a private namespace.
|
// CreatePrivateNamespace creates a private namespace.
|
||||||
|
|||||||
@@ -5,4 +5,4 @@
|
|||||||
|
|
||||||
package nci
|
package nci
|
||||||
|
|
||||||
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsyscall_windows.go nci_windows.go
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go nci_windows.go
|
||||||
|
|||||||
@@ -5,4 +5,4 @@
|
|||||||
|
|
||||||
package registry
|
package registry
|
||||||
|
|
||||||
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zregistry_windows.go registry_windows.go
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zregistry_windows.go registry_windows.go
|
||||||
|
|||||||
@@ -6,6 +6,7 @@
|
|||||||
package wintun
|
package wintun
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"runtime"
|
||||||
"unsafe"
|
"unsafe"
|
||||||
|
|
||||||
"golang.org/x/sys/windows"
|
"golang.org/x/sys/windows"
|
||||||
@@ -53,25 +54,44 @@ func PacketAlign(size uint32) uint32 {
|
|||||||
return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
|
return (size + (PacketAlignment - 1)) &^ (PacketAlignment - 1)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (descriptor *RingDescriptor) Init() (err error) {
|
func NewRingDescriptor() (descriptor *RingDescriptor, err error) {
|
||||||
|
descriptor = new(RingDescriptor)
|
||||||
|
allocatedRegion, err := windows.VirtualAlloc(0, unsafe.Sizeof(Ring{})*2, windows.MEM_COMMIT|windows.MEM_RESERVE, windows.PAGE_READWRITE)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
descriptor.free()
|
||||||
|
descriptor = nil
|
||||||
|
}
|
||||||
|
}()
|
||||||
descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
|
descriptor.Send.Size = uint32(unsafe.Sizeof(Ring{}))
|
||||||
descriptor.Send.Ring = &Ring{}
|
descriptor.Send.Ring = (*Ring)(unsafe.Pointer(allocatedRegion))
|
||||||
descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
|
descriptor.Send.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
|
descriptor.Receive.Size = uint32(unsafe.Sizeof(Ring{}))
|
||||||
descriptor.Receive.Ring = &Ring{}
|
descriptor.Receive.Ring = (*Ring)(unsafe.Pointer(allocatedRegion + unsafe.Sizeof(Ring{})))
|
||||||
descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
|
descriptor.Receive.TailMoved, err = windows.CreateEvent(nil, 0, 0, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
windows.CloseHandle(descriptor.Send.TailMoved)
|
windows.CloseHandle(descriptor.Send.TailMoved)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
runtime.SetFinalizer(descriptor, func(d *RingDescriptor) { d.free() })
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (descriptor *RingDescriptor) free() {
|
||||||
|
if descriptor.Send.Ring != nil {
|
||||||
|
windows.VirtualFree(uintptr(unsafe.Pointer(descriptor.Send.Ring)), 0, windows.MEM_RELEASE)
|
||||||
|
descriptor.Send.Ring = nil
|
||||||
|
descriptor.Receive.Ring = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (descriptor *RingDescriptor) Close() {
|
func (descriptor *RingDescriptor) Close() {
|
||||||
if descriptor.Send.TailMoved != 0 {
|
if descriptor.Send.TailMoved != 0 {
|
||||||
windows.CloseHandle(descriptor.Send.TailMoved)
|
windows.CloseHandle(descriptor.Send.TailMoved)
|
||||||
|
|||||||
@@ -5,4 +5,4 @@
|
|||||||
|
|
||||||
package setupapi
|
package setupapi
|
||||||
|
|
||||||
//go:generate go run $GOROOT/src/syscall/mksyscall_windows.go -output zsetupapi_windows.go setupapi_windows.go
|
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsetupapi_windows.go setupapi_windows.go
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ type DevInfoData struct {
|
|||||||
_ uintptr
|
_ uintptr
|
||||||
}
|
}
|
||||||
|
|
||||||
// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supercedes the functionality of SetupDiGetDeviceInfoListClass).
|
// DevInfoListDetailData is a structure for detailed information on a device information set (used for SetupDiGetDeviceInfoListDetail which supersedes the functionality of SetupDiGetDeviceInfoListClass).
|
||||||
type DevInfoListDetailData struct {
|
type DevInfoListDetailData struct {
|
||||||
size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const.
|
size uint32 // Warning: unsafe.Sizeof(DevInfoListDetailData) > sizeof(SP_DEVINFO_LIST_DETAIL_DATA) when GOARCH == 386 => use sizeofDevInfoListDetailData const.
|
||||||
ClassGUID windows.GUID
|
ClassGUID windows.GUID
|
||||||
|
|||||||
@@ -40,9 +40,9 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// makeWintun creates a Wintun interface handle and populates it from the device's registry key.
|
// makeWintun creates a Wintun interface handle and populates it from the device's registry key.
|
||||||
func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData, pool Pool) (*Interface, error) {
|
func makeWintun(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData, pool Pool) (*Interface, error) {
|
||||||
// Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
|
// Open HKEY_LOCAL_MACHINE\SYSTEM\CurrentControlSet\Control\Class\<class>\<id> registry key.
|
||||||
key, err := deviceInfoSet.OpenDevRegKey(deviceInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE)
|
key, err := devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.QUERY_VALUE)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("Device-specific registry key open failed: %v", err)
|
return nil, fmt.Errorf("Device-specific registry key open failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -72,7 +72,7 @@ func makeWintun(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfo
|
|||||||
return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err)
|
return nil, fmt.Errorf("RegQueryValue(\"*IfType\") failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
instanceID, err := deviceInfoSet.DeviceInstanceID(deviceInfoData)
|
instanceID, err := devInfo.DeviceInstanceID(devInfoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("DeviceInstanceID failed: %v", err)
|
return nil, fmt.Errorf("DeviceInstanceID failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -109,11 +109,11 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Create a list of network devices.
|
// Create a list of network devices.
|
||||||
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
|
devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err)
|
return nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err)
|
||||||
}
|
}
|
||||||
defer devInfoList.Close()
|
defer devInfo.Close()
|
||||||
|
|
||||||
// Windows requires each interface to have a different name. When
|
// Windows requires each interface to have a different name. When
|
||||||
// enforcing this, Windows treats interface names case-insensitive. If an
|
// enforcing this, Windows treats interface names case-insensitive. If an
|
||||||
@@ -123,7 +123,7 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
|
|||||||
ifname = strings.ToLower(ifname)
|
ifname = strings.ToLower(ifname)
|
||||||
|
|
||||||
for index := 0; ; index++ {
|
for index := 0; ; index++ {
|
||||||
deviceData, err := devInfoList.EnumDeviceInfo(index)
|
devInfoData, err := devInfo.EnumDeviceInfo(index)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == windows.ERROR_NO_MORE_ITEMS {
|
if err == windows.ERROR_NO_MORE_ITEMS {
|
||||||
break
|
break
|
||||||
@@ -131,7 +131,16 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
wintun, err := makeWintun(devInfoList, deviceData, pool)
|
// Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
|
||||||
|
property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
wintun, err := makeWintun(devInfo, devInfoData, pool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -145,14 +154,14 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
|
|||||||
ifname3 := removeNumberedSuffix(ifname2)
|
ifname3 := removeNumberedSuffix(ifname2)
|
||||||
|
|
||||||
if ifname == ifname2 || ifname == ifname3 {
|
if ifname == ifname2 || ifname == ifname3 {
|
||||||
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
|
err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
|
return nil, fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
|
||||||
}
|
}
|
||||||
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
|
defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
|
||||||
|
|
||||||
for index := 0; ; index++ {
|
for index := 0; ; index++ {
|
||||||
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, index)
|
driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == windows.ERROR_NO_MORE_ITEMS {
|
if err == windows.ERROR_NO_MORE_ITEMS {
|
||||||
break
|
break
|
||||||
@@ -161,13 +170,13 @@ func (pool Pool) GetInterface(ifname string) (*Interface, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get driver info details.
|
// Get driver info details.
|
||||||
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
|
driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if driverDetailData.IsCompatible(hardwareID) {
|
if driverDetailData.IsCompatible(hardwareID) {
|
||||||
isMember, err := pool.isMember(devInfoList, deviceData)
|
isMember, err := pool.isMember(devInfo, devInfoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -206,12 +215,12 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Create an empty device info set for network adapter device class.
|
// Create an empty device info set for network adapter device class.
|
||||||
devInfoList, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
|
devInfo, err := setupapi.SetupDiCreateDeviceInfoListEx(&deviceClassNetGUID, 0, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err)
|
err = fmt.Errorf("SetupDiCreateDeviceInfoListEx(%v) failed: %v", deviceClassNetGUID, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer devInfoList.Close()
|
defer devInfo.Close()
|
||||||
|
|
||||||
// Get the device class name from GUID.
|
// Get the device class name from GUID.
|
||||||
className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "")
|
className, err := setupapi.SetupDiClassNameFromGuidEx(&deviceClassNetGUID, "")
|
||||||
@@ -222,43 +231,43 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
|
|||||||
|
|
||||||
// Create a new device info element and add it to the device info set.
|
// Create a new device info element and add it to the device info set.
|
||||||
deviceTypeName := pool.deviceTypeName()
|
deviceTypeName := pool.deviceTypeName()
|
||||||
deviceData, err := devInfoList.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID)
|
devInfoData, err := devInfo.CreateDeviceInfo(className, &deviceClassNetGUID, deviceTypeName, 0, setupapi.DICD_GENERATE_ID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err)
|
err = fmt.Errorf("SetupDiCreateDeviceInfo failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = setQuietInstall(devInfoList, deviceData)
|
err = setQuietInstall(devInfo, devInfoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("Setting quiet installation failed: %v", err)
|
err = fmt.Errorf("Setting quiet installation failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set a device information element as the selected member of a device information set.
|
// Set a device information element as the selected member of a device information set.
|
||||||
err = devInfoList.SetSelectedDevice(deviceData)
|
err = devInfo.SetSelectedDevice(devInfoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err)
|
err = fmt.Errorf("SetupDiSetSelectedDevice failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Set Plug&Play device hardware ID property.
|
// Set Plug&Play device hardware ID property.
|
||||||
err = devInfoList.SetDeviceRegistryPropertyString(deviceData, setupapi.SPDRP_HARDWAREID, hardwareID)
|
err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_HARDWAREID, hardwareID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err)
|
err = fmt.Errorf("SetupDiSetDeviceRegistryProperty(SPDRP_HARDWAREID) failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms
|
err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER) // TODO: This takes ~510ms
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
|
err = fmt.Errorf("SetupDiBuildDriverInfoList failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
|
defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
|
||||||
|
|
||||||
driverDate := windows.Filetime{}
|
driverDate := windows.Filetime{}
|
||||||
driverVersion := uint64(0)
|
driverVersion := uint64(0)
|
||||||
for index := 0; ; index++ { // TODO: This loop takes ~600ms
|
for index := 0; ; index++ { // TODO: This loop takes ~600ms
|
||||||
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, index)
|
driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, index)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == windows.ERROR_NO_MORE_ITEMS {
|
if err == windows.ERROR_NO_MORE_ITEMS {
|
||||||
break
|
break
|
||||||
@@ -268,13 +277,13 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
|
|||||||
|
|
||||||
// Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match.
|
// Check the driver version first, since the check is trivial and will save us iterating over hardware IDs for any driver versioned prior our best match.
|
||||||
if driverData.IsNewer(driverDate, driverVersion) {
|
if driverData.IsNewer(driverDate, driverVersion) {
|
||||||
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
|
driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if driverDetailData.IsCompatible(hardwareID) {
|
if driverDetailData.IsCompatible(hardwareID) {
|
||||||
err := devInfoList.SetSelectedDriver(deviceData, driverData)
|
err := devInfo.SetSelectedDriver(devInfoData, driverData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -299,10 +308,10 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set class installer parameters for DIF_REMOVE.
|
// Set class installer parameters for DIF_REMOVE.
|
||||||
if devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil {
|
if devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams))) == nil {
|
||||||
// Call appropriate class installer.
|
// Call appropriate class installer.
|
||||||
if devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData) == nil {
|
if devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData) == nil {
|
||||||
rebootRequired = rebootRequired || checkReboot(devInfoList, deviceData)
|
rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -311,14 +320,14 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Call appropriate class installer.
|
// Call appropriate class installer.
|
||||||
err = devInfoList.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, deviceData)
|
err = devInfo.CallClassInstaller(setupapi.DIF_REGISTERDEVICE, devInfoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err)
|
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_REGISTERDEVICE) failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Register device co-installers if any. (Ignore errors)
|
// Register device co-installers if any. (Ignore errors)
|
||||||
devInfoList.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, deviceData)
|
devInfo.CallClassInstaller(setupapi.DIF_REGISTER_COINSTALLERS, devInfoData)
|
||||||
|
|
||||||
var netDevRegKey registry.Key
|
var netDevRegKey registry.Key
|
||||||
const pollTimeout = time.Millisecond * 50
|
const pollTimeout = time.Millisecond * 50
|
||||||
@@ -326,7 +335,7 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
|
|||||||
if i != 0 {
|
if i != 0 {
|
||||||
time.Sleep(pollTimeout)
|
time.Sleep(pollTimeout)
|
||||||
}
|
}
|
||||||
netDevRegKey, err = devInfoList.OpenDevRegKey(deviceData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY)
|
netDevRegKey, err = devInfo.OpenDevRegKey(devInfoData, setupapi.DICS_FLAG_GLOBAL, 0, setupapi.DIREG_DRV, registry.SET_VALUE|registry.QUERY_VALUE|registry.NOTIFY)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -345,17 +354,17 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Install interfaces if any. (Ignore errors)
|
// Install interfaces if any. (Ignore errors)
|
||||||
devInfoList.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, deviceData)
|
devInfo.CallClassInstaller(setupapi.DIF_INSTALLINTERFACES, devInfoData)
|
||||||
|
|
||||||
// Install the device.
|
// Install the device.
|
||||||
err = devInfoList.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, deviceData)
|
err = devInfo.CallClassInstaller(setupapi.DIF_INSTALLDEVICE, devInfoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err)
|
err = fmt.Errorf("SetupDiCallClassInstaller(DIF_INSTALLDEVICE) failed: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
rebootRequired = checkReboot(devInfoList, deviceData)
|
rebootRequired = checkReboot(devInfo, devInfoData)
|
||||||
|
|
||||||
err = devInfoList.SetDeviceRegistryPropertyString(deviceData, setupapi.SPDRP_DEVICEDESC, deviceTypeName)
|
err = devInfo.SetDeviceRegistryPropertyString(devInfoData, setupapi.SPDRP_DEVICEDESC, deviceTypeName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("SetDeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
|
err = fmt.Errorf("SetDeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
|
||||||
return
|
return
|
||||||
@@ -381,7 +390,7 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get network interface.
|
// Get network interface.
|
||||||
wintun, err = makeWintun(devInfoList, deviceData, pool)
|
wintun, err = makeWintun(devInfo, devInfoData, pool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("makeWintun failed: %v", err)
|
err = fmt.Errorf("makeWintun failed: %v", err)
|
||||||
return
|
return
|
||||||
@@ -435,14 +444,14 @@ func (pool Pool) CreateInterface(ifname string, requestedGUID *windows.GUID) (wi
|
|||||||
// if the interface was not found. It returns a bool indicating whether
|
// if the interface was not found. It returns a bool indicating whether
|
||||||
// a reboot is required.
|
// a reboot is required.
|
||||||
func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) {
|
func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) {
|
||||||
devInfoList, deviceData, err := wintun.deviceData()
|
devInfo, devInfoData, err := wintun.devInfoData()
|
||||||
if err == windows.ERROR_OBJECT_NOT_FOUND {
|
if err == windows.ERROR_OBJECT_NOT_FOUND {
|
||||||
return false, nil
|
return false, nil
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, err
|
return false, err
|
||||||
}
|
}
|
||||||
defer devInfoList.Close()
|
defer devInfo.Close()
|
||||||
|
|
||||||
// Remove the device.
|
// Remove the device.
|
||||||
removeDeviceParams := setupapi.RemoveDeviceParams{
|
removeDeviceParams := setupapi.RemoveDeviceParams{
|
||||||
@@ -451,18 +460,18 @@ func (wintun *Interface) DeleteInterface() (rebootRequired bool, err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Set class installer parameters for DIF_REMOVE.
|
// Set class installer parameters for DIF_REMOVE.
|
||||||
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
|
err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err)
|
return false, fmt.Errorf("SetupDiSetClassInstallParams failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Call appropriate class installer.
|
// Call appropriate class installer.
|
||||||
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData)
|
err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err)
|
return false, fmt.Errorf("SetupDiCallClassInstaller failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return checkReboot(devInfoList, deviceData), nil
|
return checkReboot(devInfo, devInfoData), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeleteMatchingInterfaces deletes all Wintun interfaces, which match
|
// DeleteMatchingInterfaces deletes all Wintun interfaces, which match
|
||||||
@@ -479,14 +488,14 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
|
|||||||
windows.CloseHandle(mutex)
|
windows.CloseHandle(mutex)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
|
devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())}
|
return nil, false, []error{fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())}
|
||||||
}
|
}
|
||||||
defer devInfoList.Close()
|
defer devInfo.Close()
|
||||||
|
|
||||||
for i := 0; ; i++ {
|
for i := 0; ; i++ {
|
||||||
deviceData, err := devInfoList.EnumDeviceInfo(i)
|
devInfoData, err := devInfo.EnumDeviceInfo(i)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == windows.ERROR_NO_MORE_ITEMS {
|
if err == windows.ERROR_NO_MORE_ITEMS {
|
||||||
break
|
break
|
||||||
@@ -494,22 +503,31 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err = devInfoList.BuildDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
|
// Check the Hardware ID to make sure it's a real Wintun device first. This avoids doing slow operations on non-Wintun devices.
|
||||||
|
property, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_HARDWAREID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
defer devInfoList.DestroyDriverInfoList(deviceData, setupapi.SPDIT_COMPATDRIVER)
|
if hwids, ok := property.([]string); ok && len(hwids) > 0 && hwids[0] != hardwareID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
err = devInfo.BuildDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
defer devInfo.DestroyDriverInfoList(devInfoData, setupapi.SPDIT_COMPATDRIVER)
|
||||||
|
|
||||||
isWintun := false
|
isWintun := false
|
||||||
for j := 0; ; j++ {
|
for j := 0; ; j++ {
|
||||||
driverData, err := devInfoList.EnumDriverInfo(deviceData, setupapi.SPDIT_COMPATDRIVER, j)
|
driverData, err := devInfo.EnumDriverInfo(devInfoData, setupapi.SPDIT_COMPATDRIVER, j)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == windows.ERROR_NO_MORE_ITEMS {
|
if err == windows.ERROR_NO_MORE_ITEMS {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
driverDetailData, err := devInfoList.DriverInfoDetail(deviceData, driverData)
|
driverDetailData, err := devInfo.DriverInfoDetail(devInfoData, driverData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -522,7 +540,7 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
isMember, err := pool.isMember(devInfoList, deviceData)
|
isMember, err := pool.isMember(devInfo, devInfoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors = append(errors, err)
|
errors = append(errors, err)
|
||||||
continue
|
continue
|
||||||
@@ -531,7 +549,7 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
wintun, err := makeWintun(devInfoList, deviceData, pool)
|
wintun, err := makeWintun(devInfo, devInfoData, pool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors = append(errors, fmt.Errorf("Unable to make Wintun interface object: %v", err))
|
errors = append(errors, fmt.Errorf("Unable to make Wintun interface object: %v", err))
|
||||||
continue
|
continue
|
||||||
@@ -540,41 +558,41 @@ func (pool Pool) DeleteMatchingInterfaces(matches func(wintun *Interface) bool)
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err = setQuietInstall(devInfoList, deviceData)
|
err = setQuietInstall(devInfo, devInfoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors = append(errors, err)
|
errors = append(errors, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
inst := deviceData.DevInst
|
inst := devInfoData.DevInst
|
||||||
removeDeviceParams := setupapi.RemoveDeviceParams{
|
removeDeviceParams := setupapi.RemoveDeviceParams{
|
||||||
ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
|
ClassInstallHeader: *setupapi.MakeClassInstallHeader(setupapi.DIF_REMOVE),
|
||||||
Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
|
Scope: setupapi.DI_REMOVEDEVICE_GLOBAL,
|
||||||
}
|
}
|
||||||
err = devInfoList.SetClassInstallParams(deviceData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
|
err = devInfo.SetClassInstallParams(devInfoData, &removeDeviceParams.ClassInstallHeader, uint32(unsafe.Sizeof(removeDeviceParams)))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors = append(errors, err)
|
errors = append(errors, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
err = devInfoList.CallClassInstaller(setupapi.DIF_REMOVE, deviceData)
|
err = devInfo.CallClassInstaller(setupapi.DIF_REMOVE, devInfoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors = append(errors, err)
|
errors = append(errors, err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
rebootRequired = rebootRequired || checkReboot(devInfoList, deviceData)
|
rebootRequired = rebootRequired || checkReboot(devInfo, devInfoData)
|
||||||
deviceInstancesDeleted = append(deviceInstancesDeleted, inst)
|
deviceInstancesDeleted = append(deviceInstancesDeleted, inst)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// isMember checks if SPDRP_DEVICEDESC or SPDRP_FRIENDLYNAME match device type name.
|
// isMember checks if SPDRP_DEVICEDESC or SPDRP_FRIENDLYNAME match device type name.
|
||||||
func (pool Pool) isMember(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) (bool, error) {
|
func (pool Pool) isMember(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) (bool, error) {
|
||||||
deviceDescVal, err := deviceInfoSet.DeviceRegistryProperty(deviceInfoData, setupapi.SPDRP_DEVICEDESC)
|
deviceDescVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_DEVICEDESC)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
|
return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_DEVICEDESC) failed: %v", err)
|
||||||
}
|
}
|
||||||
deviceDesc, _ := deviceDescVal.(string)
|
deviceDesc, _ := deviceDescVal.(string)
|
||||||
friendlyNameVal, err := deviceInfoSet.DeviceRegistryProperty(deviceInfoData, setupapi.SPDRP_FRIENDLYNAME)
|
friendlyNameVal, err := devInfo.DeviceRegistryProperty(devInfoData, setupapi.SPDRP_FRIENDLYNAME)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err)
|
return false, fmt.Errorf("DeviceRegistryPropertyString(SPDRP_FRIENDLYNAME) failed: %v", err)
|
||||||
}
|
}
|
||||||
@@ -585,8 +603,8 @@ func (pool Pool) isMember(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupa
|
|||||||
}
|
}
|
||||||
|
|
||||||
// checkReboot checks device install parameters if a system reboot is required.
|
// checkReboot checks device install parameters if a system reboot is required.
|
||||||
func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) bool {
|
func checkReboot(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) bool {
|
||||||
devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData)
|
devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -595,14 +613,14 @@ func checkReboot(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInf
|
|||||||
}
|
}
|
||||||
|
|
||||||
// setQuietInstall sets device install parameters for a quiet installation
|
// setQuietInstall sets device install parameters for a quiet installation
|
||||||
func setQuietInstall(deviceInfoSet setupapi.DevInfo, deviceInfoData *setupapi.DevInfoData) error {
|
func setQuietInstall(devInfo setupapi.DevInfo, devInfoData *setupapi.DevInfoData) error {
|
||||||
devInstallParams, err := deviceInfoSet.DeviceInstallParams(deviceInfoData)
|
devInstallParams, err := devInfo.DeviceInstallParams(devInfoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
devInstallParams.Flags |= setupapi.DI_QUIETINSTALL
|
devInstallParams.Flags |= setupapi.DI_QUIETINSTALL
|
||||||
return deviceInfoSet.SetDeviceInstallParams(deviceInfoData, devInstallParams)
|
return devInfo.SetDeviceInstallParams(devInfoData, devInstallParams)
|
||||||
}
|
}
|
||||||
|
|
||||||
// deviceTypeName returns pool-specific device type name.
|
// deviceTypeName returns pool-specific device type name.
|
||||||
@@ -671,11 +689,39 @@ func (wintun *Interface) tcpipAdapterRegKeyName() string {
|
|||||||
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", wintun.cfgInstanceID)
|
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\Tcpip\\Parameters\\Adapters\\%v", wintun.cfgInstanceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// deviceRegKeyName returns the device-level registry key name
|
// deviceRegKeyName returns the device-level registry key name.
|
||||||
func (wintun *Interface) deviceRegKeyName() string {
|
func (wintun *Interface) deviceRegKeyName() string {
|
||||||
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Enum\\%v", wintun.devInstanceID)
|
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Enum\\%v", wintun.devInstanceID)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Version returns the version of the Wintun driver and NDIS system currently loaded.
|
||||||
|
func (wintun *Interface) Version() (driverVersion string, ndisVersion string, err error) {
|
||||||
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE, "SYSTEM\\CurrentControlSet\\Services\\Wintun", registry.QUERY_VALUE)
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer key.Close()
|
||||||
|
driverMajor, _, err := key.GetIntegerValue("DriverMajorVersion")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
driverMinor, _, err := key.GetIntegerValue("DriverMinorVersion")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ndisMajor, _, err := key.GetIntegerValue("NdisMajorVersion")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ndisMinor, _, err := key.GetIntegerValue("NdisMinorVersion")
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
driverVersion = fmt.Sprintf("%d.%d", driverMajor, driverMinor)
|
||||||
|
ndisVersion = fmt.Sprintf("%d.%d", ndisMajor, ndisMinor)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
// tcpipInterfaceRegKeyName returns the interface-specific TCP/IP network registry key name.
|
// tcpipInterfaceRegKeyName returns the interface-specific TCP/IP network registry key name.
|
||||||
func (wintun *Interface) tcpipInterfaceRegKeyName() (path string, err error) {
|
func (wintun *Interface) tcpipInterfaceRegKeyName() (path string, err error) {
|
||||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE)
|
key, err := registry.OpenKey(registry.LOCAL_MACHINE, wintun.tcpipAdapterRegKeyName(), registry.QUERY_VALUE)
|
||||||
@@ -693,18 +739,18 @@ func (wintun *Interface) tcpipInterfaceRegKeyName() (path string, err error) {
|
|||||||
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil
|
return fmt.Sprintf("SYSTEM\\CurrentControlSet\\Services\\%s", paths[0]), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// deviceData returns TUN device info list handle and interface device info
|
// devInfoData returns TUN device info list handle and interface device info
|
||||||
// data. The device info list handle must be closed after use. In case the
|
// data. The device info list handle must be closed after use. In case the
|
||||||
// device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned.
|
// device is not found, windows.ERROR_OBJECT_NOT_FOUND is returned.
|
||||||
func (wintun *Interface) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData, error) {
|
func (wintun *Interface) devInfoData() (setupapi.DevInfo, *setupapi.DevInfoData, error) {
|
||||||
// Create a list of network devices.
|
// Create a list of network devices.
|
||||||
devInfoList, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
|
devInfo, err := setupapi.SetupDiGetClassDevsEx(&deviceClassNetGUID, "", 0, setupapi.DIGCF_PRESENT, setupapi.DevInfo(0), "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())
|
return 0, nil, fmt.Errorf("SetupDiGetClassDevsEx(%v) failed: %v", deviceClassNetGUID, err.Error())
|
||||||
}
|
}
|
||||||
|
|
||||||
for index := 0; ; index++ {
|
for index := 0; ; index++ {
|
||||||
deviceData, err := devInfoList.EnumDeviceInfo(index)
|
devInfoData, err := devInfo.EnumDeviceInfo(index)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if err == windows.ERROR_NO_MORE_ITEMS {
|
if err == windows.ERROR_NO_MORE_ITEMS {
|
||||||
break
|
break
|
||||||
@@ -714,22 +760,22 @@ func (wintun *Interface) deviceData() (setupapi.DevInfo, *setupapi.DevInfoData,
|
|||||||
|
|
||||||
// Get interface ID.
|
// Get interface ID.
|
||||||
// TODO: Store some ID in the Wintun object such that this call isn't required.
|
// TODO: Store some ID in the Wintun object such that this call isn't required.
|
||||||
wintun2, err := makeWintun(devInfoList, deviceData, wintun.pool)
|
wintun2, err := makeWintun(devInfo, devInfoData, wintun.pool)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if wintun.cfgInstanceID == wintun2.cfgInstanceID {
|
if wintun.cfgInstanceID == wintun2.cfgInstanceID {
|
||||||
err = setQuietInstall(devInfoList, deviceData)
|
err = setQuietInstall(devInfo, devInfoData)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
devInfoList.Close()
|
devInfo.Close()
|
||||||
return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err)
|
return 0, nil, fmt.Errorf("Setting quiet installation failed: %v", err)
|
||||||
}
|
}
|
||||||
return devInfoList, deviceData, nil
|
return devInfo, devInfoData, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
devInfoList.Close()
|
devInfo.Close()
|
||||||
return 0, nil, windows.ERROR_OBJECT_NOT_FOUND
|
return 0, nil, windows.ERROR_OBJECT_NOT_FOUND
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user