Compare commits
63 Commits
0.0.202001
...
0.0.202011
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
da19db415a | ||
|
|
52c834c446 | ||
|
|
913f68ce38 | ||
|
|
60b3766b89 | ||
|
|
82128c47d9 | ||
|
|
c192b2eeec | ||
|
|
a3b231b31e | ||
|
|
65e03a9182 | ||
|
|
3e08b8aee0 | ||
|
|
5ca1218a5c | ||
|
|
3b490f30aa | ||
|
|
e6b7c4eef3 | ||
|
|
8ae09213a7 | ||
|
|
36dc8b6994 | ||
|
|
2057f19a61 | ||
|
|
58a8f05f50 | ||
|
|
0b54907a73 | ||
|
|
2c143dce0f | ||
|
|
22af3890f6 | ||
|
|
c8fe925020 | ||
|
|
0cfa3314ee | ||
|
|
bc3f505efa | ||
|
|
507f148e1c | ||
|
|
31b574ef99 | ||
|
|
3c41141fb4 | ||
|
|
4369db522b | ||
|
|
b84f1d4db2 | ||
|
|
dfb28757f7 | ||
|
|
00bcd865e6 | ||
|
|
f28a6d244b | ||
|
|
c403da6a39 | ||
|
|
d6de6f3ce6 | ||
|
|
59e556f24e | ||
|
|
31faf4c159 | ||
|
|
99eb7896be | ||
|
|
f60b3919be | ||
|
|
da9d300cf8 | ||
|
|
59c9929714 | ||
|
|
db0aa39b76 | ||
|
|
bc77de2aca | ||
|
|
c8596328e7 | ||
|
|
28c4d04304 | ||
|
|
fdba6c183a | ||
|
|
250b9795f3 | ||
|
|
d60857e1a7 | ||
|
|
2fb0a712f0 | ||
|
|
f2c6faad44 | ||
|
|
c76b818466 | ||
|
|
de374bfb44 | ||
|
|
1a1c3d0968 | ||
|
|
85a45a9651 | ||
|
|
abd287159e | ||
|
|
203554620d | ||
|
|
6aefb61355 | ||
|
|
3dce460c88 | ||
|
|
224bc9e60c | ||
|
|
9cd8909df2 | ||
|
|
ae88e2a2cd | ||
|
|
4739708ca4 | ||
|
|
b33219c2cf | ||
|
|
9cbcff10dd | ||
|
|
6ed56ff2df | ||
|
|
cb4bb63030 |
5
Makefile
5
Makefile
@@ -22,7 +22,10 @@ wireguard-go: $(wildcard *.go) $(wildcard */*.go)
|
||||
install: wireguard-go
|
||||
@install -v -d "$(DESTDIR)$(BINDIR)" && install -v -m 0755 "$<" "$(DESTDIR)$(BINDIR)/wireguard-go"
|
||||
|
||||
test:
|
||||
go test -v ./...
|
||||
|
||||
clean:
|
||||
rm -f wireguard-go
|
||||
|
||||
.PHONY: all clean install generate-version-and-build
|
||||
.PHONY: all clean test install generate-version-and-build
|
||||
|
||||
@@ -26,7 +26,7 @@ To run with more logging you may set the environment variable `LOG_LEVEL=debug`.
|
||||
|
||||
### Linux
|
||||
|
||||
This will run on Linux; however **YOU SHOULD NOT RUN THIS ON LINUX**. Instead use the kernel module; see the [installation page](https://www.wireguard.com/install/) for instructions.
|
||||
This will run on Linux; however you should instead use the kernel module, which is faster and better integrated into the OS. See the [installation page](https://www.wireguard.com/install/) for instructions.
|
||||
|
||||
### macOS
|
||||
|
||||
@@ -46,7 +46,7 @@ This will run on OpenBSD. It does not yet support sticky sockets. Fwmark is mapp
|
||||
|
||||
## Building
|
||||
|
||||
This requires an installation of [go](https://golang.org) ≥ 1.12.
|
||||
This requires an installation of [go](https://golang.org) ≥ 1.13.
|
||||
|
||||
```
|
||||
$ git clone https://git.zx2c4.com/wireguard-go
|
||||
@@ -56,7 +56,7 @@ $ make
|
||||
|
||||
## License
|
||||
|
||||
Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy of
|
||||
this software and associated documentation files (the "Software"), to deal in
|
||||
|
||||
34
conn/boundif_android.go
Normal file
34
conn/boundif_android.go
Normal file
@@ -0,0 +1,34 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package conn
|
||||
|
||||
func (bind *nativeBind) PeekLookAtSocketFd4() (fd int, err error) {
|
||||
sysconn, err := bind.ipv4.SyscallConn()
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
err = sysconn.Control(func(f uintptr) {
|
||||
fd = int(f)
|
||||
})
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (bind *nativeBind) PeekLookAtSocketFd6() (fd int, err error) {
|
||||
sysconn, err := bind.ipv6.SyscallConn()
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
err = sysconn.Control(func(f uintptr) {
|
||||
fd = int(f)
|
||||
})
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1,13 +1,12 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
package conn
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
@@ -18,17 +17,13 @@ const (
|
||||
sockoptIPV6_UNICAST_IF = 31
|
||||
)
|
||||
|
||||
func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||
func (bind *nativeBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
|
||||
/* MSDN says for IPv4 this needs to be in net byte order, so that it's like an IP address with leading zeros. */
|
||||
bytes := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(bytes, interfaceIndex)
|
||||
interfaceIndex = *(*uint32)(unsafe.Pointer(&bytes[0]))
|
||||
|
||||
if device.net.bind == nil {
|
||||
return errors.New("Bind is not yet initialized")
|
||||
}
|
||||
|
||||
sysconn, err := device.net.bind.(*nativeBind).ipv4.SyscallConn()
|
||||
sysconn, err := bind.ipv4.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -41,12 +36,12 @@ func (device *Device) BindSocketToInterface4(interfaceIndex uint32, blackhole bo
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
device.net.bind.(*nativeBind).blackhole4 = blackhole
|
||||
bind.blackhole4 = blackhole
|
||||
return nil
|
||||
}
|
||||
|
||||
func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||
sysconn, err := device.net.bind.(*nativeBind).ipv6.SyscallConn()
|
||||
func (bind *nativeBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
|
||||
sysconn, err := bind.ipv6.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -59,6 +54,6 @@ func (device *Device) BindSocketToInterface6(interfaceIndex uint32, blackhole bo
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
device.net.bind.(*nativeBind).blackhole6 = blackhole
|
||||
bind.blackhole6 = blackhole
|
||||
return nil
|
||||
}
|
||||
111
conn/conn.go
Normal file
111
conn/conn.go
Normal file
@@ -0,0 +1,111 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
// Package conn implements WireGuard's network connections.
|
||||
package conn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
|
||||
//
|
||||
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
|
||||
// depending on the platform-specific implementation.
|
||||
type Bind interface {
|
||||
// LastMark reports the last mark set for this Bind.
|
||||
LastMark() uint32
|
||||
|
||||
// SetMark sets the mark for each packet sent through this Bind.
|
||||
// This mark is passed to the kernel as the socket option SO_MARK.
|
||||
SetMark(mark uint32) error
|
||||
|
||||
// ReceiveIPv6 reads an IPv6 UDP packet into b.
|
||||
//
|
||||
// It reports the number of bytes read, n,
|
||||
// the packet source address ep,
|
||||
// and any error.
|
||||
ReceiveIPv6(buff []byte) (n int, ep Endpoint, err error)
|
||||
|
||||
// ReceiveIPv4 reads an IPv4 UDP packet into b.
|
||||
//
|
||||
// It reports the number of bytes read, n,
|
||||
// the packet source address ep,
|
||||
// and any error.
|
||||
ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
|
||||
|
||||
// Send writes a packet b to address ep.
|
||||
Send(b []byte, ep Endpoint) error
|
||||
|
||||
// Close closes the Bind connection.
|
||||
Close() error
|
||||
}
|
||||
|
||||
// CreateBind creates a Bind bound to a port.
|
||||
//
|
||||
// The value actualPort reports the actual port number the Bind
|
||||
// object gets bound to.
|
||||
func CreateBind(port uint16) (b Bind, actualPort uint16, err error) {
|
||||
return createBind(port)
|
||||
}
|
||||
|
||||
// BindSocketToInterface is implemented by Bind objects that support being
|
||||
// tied to a single network interface. Used by wireguard-windows.
|
||||
type BindSocketToInterface interface {
|
||||
BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error
|
||||
BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error
|
||||
}
|
||||
|
||||
// PeekLookAtSocketFd is implemented by Bind objects that support having their
|
||||
// file descriptor peeked at. Used by wireguard-android.
|
||||
type PeekLookAtSocketFd interface {
|
||||
PeekLookAtSocketFd4() (fd int, err error)
|
||||
PeekLookAtSocketFd6() (fd int, err error)
|
||||
}
|
||||
|
||||
// An Endpoint maintains the source/destination caching for a peer.
|
||||
//
|
||||
// dst : the remote address of a peer ("endpoint" in uapi terminology)
|
||||
// src : the local address from which datagrams originate going to the peer
|
||||
type Endpoint interface {
|
||||
ClearSrc() // clears the source address
|
||||
SrcToString() string // returns the local source address (ip:port)
|
||||
DstToString() string // returns the destination address (ip:port)
|
||||
DstToBytes() []byte // used for mac2 cookie calculations
|
||||
DstIP() net.IP
|
||||
SrcIP() net.IP
|
||||
}
|
||||
|
||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
||||
// ensure that the host is an IP address
|
||||
|
||||
host, _, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
|
||||
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
|
||||
// trying to make sure with a small sanity test that this is a real IP address and
|
||||
// not something that's likely to incur DNS lookups.
|
||||
host = host[:i]
|
||||
}
|
||||
if ip := net.ParseIP(host); ip == nil {
|
||||
return nil, errors.New("Failed to parse IP address: " + host)
|
||||
}
|
||||
|
||||
// parse address and port
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ip4 := addr.IP.To4()
|
||||
if ip4 != nil {
|
||||
addr.IP = ip4
|
||||
}
|
||||
return addr, err
|
||||
}
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
package conn
|
||||
|
||||
import (
|
||||
"net"
|
||||
@@ -67,16 +67,12 @@ func (e *NativeEndpoint) SrcToString() string {
|
||||
}
|
||||
|
||||
func listenNet(network string, port int) (*net.UDPConn, int, error) {
|
||||
|
||||
// listen
|
||||
|
||||
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
// retrieve port
|
||||
|
||||
// Retrieve port.
|
||||
laddr := conn.LocalAddr()
|
||||
uaddr, err := net.ResolveUDPAddr(
|
||||
laddr.Network(),
|
||||
@@ -100,7 +96,7 @@ func extractErrno(err error) error {
|
||||
return syscallErr.Err
|
||||
}
|
||||
|
||||
func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
|
||||
func createBind(uport uint16) (Bind, uint16, error) {
|
||||
var err error
|
||||
var bind nativeBind
|
||||
|
||||
@@ -135,6 +131,8 @@ func (bind *nativeBind) Close() error {
|
||||
return err2
|
||||
}
|
||||
|
||||
func (bind *nativeBind) LastMark() uint32 { return 0 }
|
||||
|
||||
func (bind *nativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
||||
if bind.ipv4 == nil {
|
||||
return 0, nil, syscall.EAFNOSUPPORT
|
||||
@@ -2,19 +2,10 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* This implements userspace semantics of "sticky sockets", modeled after
|
||||
* WireGuard's kernelspace implementation. This is more or less a straight port
|
||||
* of the sticky-sockets.c example code:
|
||||
* https://git.zx2c4.com/wireguard-tools/tree/contrib/sticky-sockets/sticky-sockets.c
|
||||
*
|
||||
* Currently there is no way to achieve this within the net package:
|
||||
* See e.g. https://github.com/golang/go/issues/17930
|
||||
* So this code is remains platform dependent.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
package conn
|
||||
|
||||
import (
|
||||
"errors"
|
||||
@@ -25,7 +16,6 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -33,8 +23,8 @@ const (
|
||||
)
|
||||
|
||||
type IPv4Source struct {
|
||||
src [4]byte
|
||||
ifindex int32
|
||||
Src [4]byte
|
||||
Ifindex int32
|
||||
}
|
||||
|
||||
type IPv6Source struct {
|
||||
@@ -49,6 +39,10 @@ type NativeEndpoint struct {
|
||||
isV6 bool
|
||||
}
|
||||
|
||||
func (endpoint *NativeEndpoint) Src4() *IPv4Source { return endpoint.src4() }
|
||||
func (endpoint *NativeEndpoint) Dst4() *unix.SockaddrInet4 { return endpoint.dst4() }
|
||||
func (endpoint *NativeEndpoint) IsV6() bool { return endpoint.isV6 }
|
||||
|
||||
func (endpoint *NativeEndpoint) src4() *IPv4Source {
|
||||
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
|
||||
}
|
||||
@@ -68,8 +62,6 @@ func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
|
||||
type nativeBind struct {
|
||||
sock4 int
|
||||
sock6 int
|
||||
netlinkSock int
|
||||
netlinkCancel *rwcancel.RWCancel
|
||||
lastMark uint32
|
||||
}
|
||||
|
||||
@@ -111,59 +103,25 @@ func CreateEndpoint(s string) (Endpoint, error) {
|
||||
return nil, errors.New("Invalid IP address")
|
||||
}
|
||||
|
||||
func createNetlinkRouteSocket() (int, error) {
|
||||
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
saddr := &unix.SockaddrNetlink{
|
||||
Family: unix.AF_NETLINK,
|
||||
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
|
||||
}
|
||||
err = unix.Bind(sock, saddr)
|
||||
if err != nil {
|
||||
unix.Close(sock)
|
||||
return -1, err
|
||||
}
|
||||
return sock, nil
|
||||
|
||||
}
|
||||
|
||||
func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
|
||||
func createBind(port uint16) (Bind, uint16, error) {
|
||||
var err error
|
||||
var bind nativeBind
|
||||
var newPort uint16
|
||||
|
||||
bind.netlinkSock, err = createNetlinkRouteSocket()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
|
||||
if err != nil {
|
||||
unix.Close(bind.netlinkSock)
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
go bind.routineRouteListener(device)
|
||||
|
||||
// attempt ipv6 bind, update port if successful
|
||||
|
||||
// Attempt ipv6 bind, update port if successful.
|
||||
bind.sock6, newPort, err = create6(port)
|
||||
if err != nil {
|
||||
if err != syscall.EAFNOSUPPORT {
|
||||
bind.netlinkCancel.Cancel()
|
||||
return nil, 0, err
|
||||
}
|
||||
} else {
|
||||
port = newPort
|
||||
}
|
||||
|
||||
// attempt ipv4 bind, update port if successful
|
||||
|
||||
// Attempt ipv4 bind, update port if successful.
|
||||
bind.sock4, newPort, err = create4(port)
|
||||
if err != nil {
|
||||
if err != syscall.EAFNOSUPPORT {
|
||||
bind.netlinkCancel.Cancel()
|
||||
unix.Close(bind.sock6)
|
||||
return nil, 0, err
|
||||
}
|
||||
@@ -178,6 +136,10 @@ func CreateBind(port uint16, device *Device) (*nativeBind, uint16, error) {
|
||||
return &bind, port, nil
|
||||
}
|
||||
|
||||
func (bind *nativeBind) LastMark() uint32 {
|
||||
return bind.lastMark
|
||||
}
|
||||
|
||||
func (bind *nativeBind) SetMark(value uint32) error {
|
||||
if bind.sock6 != -1 {
|
||||
err := unix.SetsockoptInt(
|
||||
@@ -216,22 +178,18 @@ func closeUnblock(fd int) error {
|
||||
}
|
||||
|
||||
func (bind *nativeBind) Close() error {
|
||||
var err1, err2, err3 error
|
||||
var err1, err2 error
|
||||
if bind.sock6 != -1 {
|
||||
err1 = closeUnblock(bind.sock6)
|
||||
}
|
||||
if bind.sock4 != -1 {
|
||||
err2 = closeUnblock(bind.sock4)
|
||||
}
|
||||
err3 = bind.netlinkCancel.Cancel()
|
||||
|
||||
if err1 != nil {
|
||||
return err1
|
||||
}
|
||||
if err2 != nil {
|
||||
return err2
|
||||
}
|
||||
return err3
|
||||
}
|
||||
|
||||
func (bind *nativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||
@@ -278,10 +236,10 @@ func (bind *nativeBind) Send(buff []byte, end Endpoint) error {
|
||||
func (end *NativeEndpoint) SrcIP() net.IP {
|
||||
if !end.isV6 {
|
||||
return net.IPv4(
|
||||
end.src4().src[0],
|
||||
end.src4().src[1],
|
||||
end.src4().src[2],
|
||||
end.src4().src[3],
|
||||
end.src4().Src[0],
|
||||
end.src4().Src[1],
|
||||
end.src4().Src[2],
|
||||
end.src4().Src[3],
|
||||
)
|
||||
} else {
|
||||
return end.src6().src[:]
|
||||
@@ -478,8 +436,8 @@ func send4(sock int, end *NativeEndpoint, buff []byte) error {
|
||||
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
|
||||
},
|
||||
unix.Inet4Pktinfo{
|
||||
Spec_dst: end.src4().src,
|
||||
Ifindex: end.src4().ifindex,
|
||||
Spec_dst: end.src4().Src,
|
||||
Ifindex: end.src4().Ifindex,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -573,8 +531,8 @@ func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
|
||||
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
|
||||
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
|
||||
end.src4().src = cmsg.pktinfo.Spec_dst
|
||||
end.src4().ifindex = cmsg.pktinfo.Ifindex
|
||||
end.src4().Src = cmsg.pktinfo.Spec_dst
|
||||
end.src4().Ifindex = cmsg.pktinfo.Ifindex
|
||||
}
|
||||
|
||||
return size, nil
|
||||
@@ -611,156 +569,3 @@ func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
|
||||
|
||||
return size, nil
|
||||
}
|
||||
|
||||
func (bind *nativeBind) routineRouteListener(device *Device) {
|
||||
type peerEndpointPtr struct {
|
||||
peer *Peer
|
||||
endpoint *Endpoint
|
||||
}
|
||||
var reqPeer map[uint32]peerEndpointPtr
|
||||
var reqPeerLock sync.Mutex
|
||||
|
||||
defer unix.Close(bind.netlinkSock)
|
||||
|
||||
for msg := make([]byte, 1<<16); ; {
|
||||
var err error
|
||||
var msgn int
|
||||
for {
|
||||
msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
|
||||
if err == nil || !rwcancel.RetryAfterError(err) {
|
||||
break
|
||||
}
|
||||
if !bind.netlinkCancel.ReadyRead() {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||
|
||||
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||
|
||||
if uint(hdr.Len) > uint(len(remain)) {
|
||||
break
|
||||
}
|
||||
|
||||
switch hdr.Type {
|
||||
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
|
||||
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
|
||||
if uint(len(remain)) < uint(hdr.Len) {
|
||||
break
|
||||
}
|
||||
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
|
||||
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
|
||||
for {
|
||||
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
|
||||
break
|
||||
}
|
||||
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
|
||||
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
|
||||
break
|
||||
}
|
||||
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
|
||||
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
|
||||
reqPeerLock.Lock()
|
||||
if reqPeer == nil {
|
||||
reqPeerLock.Unlock()
|
||||
break
|
||||
}
|
||||
pePtr, ok := reqPeer[hdr.Seq]
|
||||
reqPeerLock.Unlock()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
pePtr.peer.Lock()
|
||||
if &pePtr.peer.endpoint != pePtr.endpoint {
|
||||
pePtr.peer.Unlock()
|
||||
break
|
||||
}
|
||||
if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
|
||||
pePtr.peer.Unlock()
|
||||
break
|
||||
}
|
||||
pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
|
||||
pePtr.peer.Unlock()
|
||||
}
|
||||
attr = attr[attrhdr.Len:]
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
reqPeerLock.Lock()
|
||||
reqPeer = make(map[uint32]peerEndpointPtr)
|
||||
reqPeerLock.Unlock()
|
||||
go func() {
|
||||
device.peers.RLock()
|
||||
i := uint32(1)
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.RLock()
|
||||
if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
|
||||
peer.RUnlock()
|
||||
continue
|
||||
}
|
||||
if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
|
||||
peer.RUnlock()
|
||||
break
|
||||
}
|
||||
nlmsg := struct {
|
||||
hdr unix.NlMsghdr
|
||||
msg unix.RtMsg
|
||||
dsthdr unix.RtAttr
|
||||
dst [4]byte
|
||||
srchdr unix.RtAttr
|
||||
src [4]byte
|
||||
markhdr unix.RtAttr
|
||||
mark uint32
|
||||
}{
|
||||
unix.NlMsghdr{
|
||||
Type: uint16(unix.RTM_GETROUTE),
|
||||
Flags: unix.NLM_F_REQUEST,
|
||||
Seq: i,
|
||||
},
|
||||
unix.RtMsg{
|
||||
Family: unix.AF_INET,
|
||||
Dst_len: 32,
|
||||
Src_len: 32,
|
||||
},
|
||||
unix.RtAttr{
|
||||
Len: 8,
|
||||
Type: unix.RTA_DST,
|
||||
},
|
||||
peer.endpoint.(*NativeEndpoint).dst4().Addr,
|
||||
unix.RtAttr{
|
||||
Len: 8,
|
||||
Type: unix.RTA_SRC,
|
||||
},
|
||||
peer.endpoint.(*NativeEndpoint).src4().src,
|
||||
unix.RtAttr{
|
||||
Len: 8,
|
||||
Type: unix.RTA_MARK,
|
||||
},
|
||||
uint32(bind.lastMark),
|
||||
}
|
||||
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
|
||||
reqPeerLock.Lock()
|
||||
reqPeer[i] = peerEndpointPtr{
|
||||
peer: peer,
|
||||
endpoint: &peer.endpoint,
|
||||
}
|
||||
reqPeerLock.Unlock()
|
||||
peer.RUnlock()
|
||||
i++
|
||||
_, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
}()
|
||||
}
|
||||
remain = remain[hdr.Len:]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
package conn
|
||||
|
||||
func (bind *nativeBind) SetMark(mark uint32) error {
|
||||
return nil
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
package conn
|
||||
|
||||
import (
|
||||
"runtime"
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -1,15 +1,19 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import "errors"
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type DummyDatagram struct {
|
||||
msg []byte
|
||||
endpoint Endpoint
|
||||
endpoint conn.Endpoint
|
||||
world bool // better type
|
||||
}
|
||||
|
||||
@@ -25,7 +29,7 @@ func (b *DummyBind) SetMark(v uint32) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, conn.Endpoint, error) {
|
||||
datagram, ok := <-b.in6
|
||||
if !ok {
|
||||
return 0, nil, errors.New("closed")
|
||||
@@ -34,7 +38,7 @@ func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
|
||||
return len(datagram.msg), datagram.endpoint, nil
|
||||
}
|
||||
|
||||
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
|
||||
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, conn.Endpoint, error) {
|
||||
datagram, ok := <-b.in4
|
||||
if !ok {
|
||||
return 0, nil, errors.New("closed")
|
||||
@@ -50,6 +54,6 @@ func (b *DummyBind) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *DummyBind) Send(buff []byte, end Endpoint) error {
|
||||
func (b *DummyBind) Send(buff []byte, end conn.Endpoint) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import "errors"
|
||||
|
||||
func (device *Device) PeekLookAtSocketFd4() (fd int, err error) {
|
||||
nb, ok := device.net.bind.(*nativeBind)
|
||||
if !ok {
|
||||
return 0, errors.New("no socket exists")
|
||||
}
|
||||
sysconn, err := nb.ipv4.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = sysconn.Control(func(f uintptr) {
|
||||
fd = int(f)
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (device *Device) PeekLookAtSocketFd6() (fd int, err error) {
|
||||
nb, ok := device.net.bind.(*nativeBind)
|
||||
if !ok {
|
||||
return 0, errors.New("no socket exists")
|
||||
}
|
||||
sysconn, err := nb.ipv6.SyscallConn()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = sysconn.Control(func(f uintptr) {
|
||||
fd = int(f)
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
187
device/conn.go
187
device/conn.go
@@ -1,187 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
)
|
||||
|
||||
const (
|
||||
ConnRoutineNumber = 2
|
||||
)
|
||||
|
||||
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
|
||||
*/
|
||||
type Bind interface {
|
||||
SetMark(value uint32) error
|
||||
ReceiveIPv6(buff []byte) (int, Endpoint, error)
|
||||
ReceiveIPv4(buff []byte) (int, Endpoint, error)
|
||||
Send(buff []byte, end Endpoint) error
|
||||
Close() error
|
||||
}
|
||||
|
||||
/* An Endpoint maintains the source/destination caching for a peer
|
||||
*
|
||||
* dst : the remote address of a peer ("endpoint" in uapi terminology)
|
||||
* src : the local address from which datagrams originate going to the peer
|
||||
*/
|
||||
type Endpoint interface {
|
||||
ClearSrc() // clears the source address
|
||||
SrcToString() string // returns the local source address (ip:port)
|
||||
DstToString() string // returns the destination address (ip:port)
|
||||
DstToBytes() []byte // used for mac2 cookie calculations
|
||||
DstIP() net.IP
|
||||
SrcIP() net.IP
|
||||
}
|
||||
|
||||
func parseEndpoint(s string) (*net.UDPAddr, error) {
|
||||
// ensure that the host is an IP address
|
||||
|
||||
host, _, err := net.SplitHostPort(s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if i := strings.LastIndexByte(host, '%'); i > 0 && strings.IndexByte(host, ':') >= 0 {
|
||||
// Remove the scope, if any. ResolveUDPAddr below will use it, but here we're just
|
||||
// trying to make sure with a small sanity test that this is a real IP address and
|
||||
// not something that's likely to incur DNS lookups.
|
||||
host = host[:i]
|
||||
}
|
||||
if ip := net.ParseIP(host); ip == nil {
|
||||
return nil, errors.New("Failed to parse IP address: " + host)
|
||||
}
|
||||
|
||||
// parse address and port
|
||||
|
||||
addr, err := net.ResolveUDPAddr("udp", s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ip4 := addr.IP.To4()
|
||||
if ip4 != nil {
|
||||
addr.IP = ip4
|
||||
}
|
||||
return addr, err
|
||||
}
|
||||
|
||||
func unsafeCloseBind(device *Device) error {
|
||||
var err error
|
||||
netc := &device.net
|
||||
if netc.bind != nil {
|
||||
err = netc.bind.Close()
|
||||
netc.bind = nil
|
||||
}
|
||||
netc.stopping.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
func (device *Device) BindSetMark(mark uint32) error {
|
||||
|
||||
device.net.Lock()
|
||||
defer device.net.Unlock()
|
||||
|
||||
// check if modified
|
||||
|
||||
if device.net.fwmark == mark {
|
||||
return nil
|
||||
}
|
||||
|
||||
// update fwmark on existing bind
|
||||
|
||||
device.net.fwmark = mark
|
||||
if device.isUp.Get() && device.net.bind != nil {
|
||||
if err := device.net.bind.SetMark(mark); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// clear cached source addresses
|
||||
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (device *Device) BindUpdate() error {
|
||||
|
||||
device.net.Lock()
|
||||
defer device.net.Unlock()
|
||||
|
||||
// close existing sockets
|
||||
|
||||
if err := unsafeCloseBind(device); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// open new sockets
|
||||
|
||||
if device.isUp.Get() {
|
||||
|
||||
// bind to new port
|
||||
|
||||
var err error
|
||||
netc := &device.net
|
||||
netc.bind, netc.port, err = CreateBind(netc.port, device)
|
||||
if err != nil {
|
||||
netc.bind = nil
|
||||
netc.port = 0
|
||||
return err
|
||||
}
|
||||
|
||||
// set fwmark
|
||||
|
||||
if netc.fwmark != 0 {
|
||||
err = netc.bind.SetMark(netc.fwmark)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// clear cached source addresses
|
||||
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
|
||||
// start receiving routines
|
||||
|
||||
device.net.starting.Add(ConnRoutineNumber)
|
||||
device.net.stopping.Add(ConnRoutineNumber)
|
||||
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
||||
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
||||
device.net.starting.Wait()
|
||||
|
||||
device.log.Debug.Println("UDP bind has been updated")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (device *Device) BindClose() error {
|
||||
device.net.Lock()
|
||||
err := unsafeCloseBind(device)
|
||||
device.net.Unlock()
|
||||
return err
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@@ -13,7 +13,7 @@ import (
|
||||
|
||||
const (
|
||||
RekeyAfterMessages = (1 << 60)
|
||||
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
|
||||
RejectAfterMessages = (1 << 64) - (1 << 13) - 1
|
||||
RekeyAfterTime = time.Second * 120
|
||||
RekeyAttemptTime = time.Second * 90
|
||||
RekeyTimeout = time.Second * 5
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
166
device/device.go
166
device/device.go
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@@ -11,15 +11,14 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/ratelimiter"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
const (
|
||||
DeviceRoutineNumberPerCPU = 3
|
||||
DeviceRoutineNumberAdditional = 2
|
||||
)
|
||||
|
||||
type Device struct {
|
||||
isUp AtomicBool // device is (going) up
|
||||
isClosed AtomicBool // device is closed? (acting as guard)
|
||||
@@ -39,7 +38,8 @@ type Device struct {
|
||||
starting sync.WaitGroup
|
||||
stopping sync.WaitGroup
|
||||
sync.RWMutex
|
||||
bind Bind // bind interface
|
||||
bind conn.Bind // bind interface
|
||||
netlinkCancel *rwcancel.RWCancel
|
||||
port uint16 // listening port
|
||||
fwmark uint32 // mark value (0 = disabled)
|
||||
}
|
||||
@@ -236,24 +236,12 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
||||
|
||||
// do static-static DH pre-computations
|
||||
|
||||
rmKey := device.staticIdentity.privateKey.IsZero()
|
||||
|
||||
expiredPeers := make([]*Peer, 0, len(device.peers.keyMap))
|
||||
for key, peer := range device.peers.keyMap {
|
||||
for _, peer := range device.peers.keyMap {
|
||||
handshake := &peer.handshake
|
||||
|
||||
if rmKey {
|
||||
handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
|
||||
} else {
|
||||
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
|
||||
}
|
||||
|
||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||
unsafeRemovePeer(device, peer, key)
|
||||
} else {
|
||||
expiredPeers = append(expiredPeers, peer)
|
||||
}
|
||||
}
|
||||
|
||||
for _, peer := range lockedPeers {
|
||||
peer.handshake.mutex.RUnlock()
|
||||
@@ -311,14 +299,16 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
||||
cpus := runtime.NumCPU()
|
||||
device.state.starting.Wait()
|
||||
device.state.stopping.Wait()
|
||||
device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
|
||||
device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
|
||||
for i := 0; i < cpus; i += 1 {
|
||||
device.state.starting.Add(3)
|
||||
device.state.stopping.Add(3)
|
||||
go device.RoutineEncryption()
|
||||
go device.RoutineDecryption()
|
||||
go device.RoutineHandshake()
|
||||
}
|
||||
|
||||
device.state.starting.Add(2)
|
||||
device.state.stopping.Add(2)
|
||||
go device.RoutineReadFromTUN()
|
||||
go device.RoutineTUNEventReader()
|
||||
|
||||
@@ -393,10 +383,10 @@ func (device *Device) Close() {
|
||||
device.isUp.Set(false)
|
||||
|
||||
close(device.signals.stop)
|
||||
device.state.stopping.Wait()
|
||||
|
||||
device.RemoveAllPeers()
|
||||
|
||||
device.state.stopping.Wait()
|
||||
device.FlushPacketQueues()
|
||||
|
||||
device.rate.limiter.Close()
|
||||
@@ -425,3 +415,133 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
}
|
||||
|
||||
func unsafeCloseBind(device *Device) error {
|
||||
var err error
|
||||
netc := &device.net
|
||||
if netc.netlinkCancel != nil {
|
||||
netc.netlinkCancel.Cancel()
|
||||
}
|
||||
if netc.bind != nil {
|
||||
err = netc.bind.Close()
|
||||
netc.bind = nil
|
||||
}
|
||||
netc.stopping.Wait()
|
||||
return err
|
||||
}
|
||||
|
||||
func (device *Device) Bind() conn.Bind {
|
||||
device.net.Lock()
|
||||
defer device.net.Unlock()
|
||||
return device.net.bind
|
||||
}
|
||||
|
||||
func (device *Device) BindSetMark(mark uint32) error {
|
||||
|
||||
device.net.Lock()
|
||||
defer device.net.Unlock()
|
||||
|
||||
// check if modified
|
||||
|
||||
if device.net.fwmark == mark {
|
||||
return nil
|
||||
}
|
||||
|
||||
// update fwmark on existing bind
|
||||
|
||||
device.net.fwmark = mark
|
||||
if device.isUp.Get() && device.net.bind != nil {
|
||||
if err := device.net.bind.SetMark(mark); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// clear cached source addresses
|
||||
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (device *Device) BindUpdate() error {
|
||||
|
||||
device.net.Lock()
|
||||
defer device.net.Unlock()
|
||||
|
||||
// close existing sockets
|
||||
|
||||
if err := unsafeCloseBind(device); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// open new sockets
|
||||
|
||||
if device.isUp.Get() {
|
||||
|
||||
// bind to new port
|
||||
|
||||
var err error
|
||||
netc := &device.net
|
||||
netc.bind, netc.port, err = conn.CreateBind(netc.port)
|
||||
if err != nil {
|
||||
netc.bind = nil
|
||||
netc.port = 0
|
||||
return err
|
||||
}
|
||||
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
|
||||
if err != nil {
|
||||
netc.bind.Close()
|
||||
netc.bind = nil
|
||||
netc.port = 0
|
||||
return err
|
||||
}
|
||||
|
||||
// set fwmark
|
||||
|
||||
if netc.fwmark != 0 {
|
||||
err = netc.bind.SetMark(netc.fwmark)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
// clear cached source addresses
|
||||
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
if peer.endpoint != nil {
|
||||
peer.endpoint.ClearSrc()
|
||||
}
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
|
||||
// start receiving routines
|
||||
|
||||
device.net.starting.Add(2)
|
||||
device.net.stopping.Add(2)
|
||||
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
||||
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
||||
device.net.starting.Wait()
|
||||
|
||||
device.log.Debug.Println("UDP bind has been updated")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (device *Device) BindClose() error {
|
||||
device.net.Lock()
|
||||
err := unsafeCloseBind(device)
|
||||
device.net.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@@ -8,28 +8,40 @@ package device
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
"golang.zx2c4.com/wireguard/tun/tuntest"
|
||||
)
|
||||
|
||||
func getFreePort(t *testing.T) string {
|
||||
l, err := net.ListenPacket("udp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer l.Close()
|
||||
return fmt.Sprintf("%d", l.LocalAddr().(*net.UDPAddr).Port)
|
||||
}
|
||||
|
||||
func TestTwoDevicePing(t *testing.T) {
|
||||
// TODO(crawshaw): pick unused ports on localhost
|
||||
port1 := getFreePort(t)
|
||||
port2 := getFreePort(t)
|
||||
|
||||
cfg1 := `private_key=481eb0d8113a4a5da532d2c3e9c14b53c8454b34ab109676f6b58c2245e37b58
|
||||
listen_port=53511
|
||||
listen_port={{PORT1}}
|
||||
replace_peers=true
|
||||
public_key=f70dbb6b1b92a1dde1c783b297016af3f572fef13b0abb16a2623d89a58e9725
|
||||
protocol_version=1
|
||||
replace_allowed_ips=true
|
||||
allowed_ip=1.0.0.2/32
|
||||
endpoint=127.0.0.1:53512`
|
||||
tun1 := NewChannelTUN()
|
||||
endpoint=127.0.0.1:{{PORT2}}`
|
||||
cfg1 = strings.ReplaceAll(cfg1, "{{PORT1}}", port1)
|
||||
cfg1 = strings.ReplaceAll(cfg1, "{{PORT2}}", port2)
|
||||
|
||||
tun1 := tuntest.NewChannelTUN()
|
||||
dev1 := NewDevice(tun1.TUN(), NewLogger(LogLevelDebug, "dev1: "))
|
||||
dev1.Up()
|
||||
defer dev1.Close()
|
||||
@@ -38,14 +50,17 @@ endpoint=127.0.0.1:53512`
|
||||
}
|
||||
|
||||
cfg2 := `private_key=98c7989b1661a0d64fd6af3502000f87716b7c4bbcf00d04fc6073aa7b539768
|
||||
listen_port=53512
|
||||
listen_port={{PORT2}}
|
||||
replace_peers=true
|
||||
public_key=49e80929259cebdda4f322d6d2b1a6fad819d603acd26fd5d845e7a123036427
|
||||
protocol_version=1
|
||||
replace_allowed_ips=true
|
||||
allowed_ip=1.0.0.1/32
|
||||
endpoint=127.0.0.1:53511`
|
||||
tun2 := NewChannelTUN()
|
||||
endpoint=127.0.0.1:{{PORT1}}`
|
||||
cfg2 = strings.ReplaceAll(cfg2, "{{PORT1}}", port1)
|
||||
cfg2 = strings.ReplaceAll(cfg2, "{{PORT2}}", port2)
|
||||
|
||||
tun2 := tuntest.NewChannelTUN()
|
||||
dev2 := NewDevice(tun2.TUN(), NewLogger(LogLevelDebug, "dev2: "))
|
||||
dev2.Up()
|
||||
defer dev2.Close()
|
||||
@@ -54,7 +69,7 @@ endpoint=127.0.0.1:53511`
|
||||
}
|
||||
|
||||
t.Run("ping 1.0.0.1", func(t *testing.T) {
|
||||
msg2to1 := ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
|
||||
msg2to1 := tuntest.Ping(net.ParseIP("1.0.0.1"), net.ParseIP("1.0.0.2"))
|
||||
tun2.Outbound <- msg2to1
|
||||
select {
|
||||
case msgRecv := <-tun1.Inbound:
|
||||
@@ -67,7 +82,7 @@ endpoint=127.0.0.1:53511`
|
||||
})
|
||||
|
||||
t.Run("ping 1.0.0.2", func(t *testing.T) {
|
||||
msg1to2 := ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
|
||||
msg1to2 := tuntest.Ping(net.ParseIP("1.0.0.2"), net.ParseIP("1.0.0.1"))
|
||||
tun1.Outbound <- msg1to2
|
||||
select {
|
||||
case msgRecv := <-tun2.Inbound:
|
||||
@@ -80,139 +95,6 @@ endpoint=127.0.0.1:53511`
|
||||
})
|
||||
}
|
||||
|
||||
func ping(dst, src net.IP) []byte {
|
||||
localPort := uint16(1337)
|
||||
seq := uint16(0)
|
||||
|
||||
payload := make([]byte, 4)
|
||||
binary.BigEndian.PutUint16(payload[0:], localPort)
|
||||
binary.BigEndian.PutUint16(payload[2:], seq)
|
||||
|
||||
return genICMPv4(payload, dst, src)
|
||||
}
|
||||
|
||||
// checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
|
||||
func checksum(buf []byte, initial uint16) uint16 {
|
||||
v := uint32(initial)
|
||||
for i := 0; i < len(buf)-1; i += 2 {
|
||||
v += uint32(binary.BigEndian.Uint16(buf[i:]))
|
||||
}
|
||||
if len(buf)%2 == 1 {
|
||||
v += uint32(buf[len(buf)-1]) << 8
|
||||
}
|
||||
for v > 0xffff {
|
||||
v = (v >> 16) + (v & 0xffff)
|
||||
}
|
||||
return ^uint16(v)
|
||||
}
|
||||
|
||||
func genICMPv4(payload []byte, dst, src net.IP) []byte {
|
||||
const (
|
||||
icmpv4ProtocolNumber = 1
|
||||
icmpv4Echo = 8
|
||||
icmpv4ChecksumOffset = 2
|
||||
icmpv4Size = 8
|
||||
ipv4Size = 20
|
||||
ipv4TotalLenOffset = 2
|
||||
ipv4ChecksumOffset = 10
|
||||
ttl = 65
|
||||
)
|
||||
|
||||
hdr := make([]byte, ipv4Size+icmpv4Size)
|
||||
|
||||
ip := hdr[0:ipv4Size]
|
||||
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
|
||||
|
||||
// https://tools.ietf.org/html/rfc792
|
||||
icmpv4[0] = icmpv4Echo // type
|
||||
icmpv4[1] = 0 // code
|
||||
chksum := ^checksum(icmpv4, checksum(payload, 0))
|
||||
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
|
||||
|
||||
// https://tools.ietf.org/html/rfc760 section 3.1
|
||||
length := uint16(len(hdr) + len(payload))
|
||||
ip[0] = (4 << 4) | (ipv4Size / 4)
|
||||
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
|
||||
ip[8] = ttl
|
||||
ip[9] = icmpv4ProtocolNumber
|
||||
copy(ip[12:], src.To4())
|
||||
copy(ip[16:], dst.To4())
|
||||
chksum = ^checksum(ip[:], 0)
|
||||
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
|
||||
|
||||
var v []byte
|
||||
v = append(v, hdr...)
|
||||
v = append(v, payload...)
|
||||
return []byte(v)
|
||||
}
|
||||
|
||||
// TODO(crawshaw): find a reusable home for this. package devicetest?
|
||||
type ChannelTUN struct {
|
||||
Inbound chan []byte // incoming packets, closed on TUN close
|
||||
Outbound chan []byte // outbound packets, blocks forever on TUN close
|
||||
|
||||
closed chan struct{}
|
||||
events chan tun.Event
|
||||
tun chTun
|
||||
}
|
||||
|
||||
func NewChannelTUN() *ChannelTUN {
|
||||
c := &ChannelTUN{
|
||||
Inbound: make(chan []byte),
|
||||
Outbound: make(chan []byte),
|
||||
closed: make(chan struct{}),
|
||||
events: make(chan tun.Event, 1),
|
||||
}
|
||||
c.tun.c = c
|
||||
c.events <- tun.EventUp
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *ChannelTUN) TUN() tun.Device {
|
||||
return &c.tun
|
||||
}
|
||||
|
||||
type chTun struct {
|
||||
c *ChannelTUN
|
||||
}
|
||||
|
||||
func (t *chTun) File() *os.File { return nil }
|
||||
|
||||
func (t *chTun) Read(data []byte, offset int) (int, error) {
|
||||
select {
|
||||
case <-t.c.closed:
|
||||
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
|
||||
case msg := <-t.c.Outbound:
|
||||
return copy(data[offset:], msg), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Write is called by the wireguard device to deliver a packet for routing.
|
||||
func (t *chTun) Write(data []byte, offset int) (int, error) {
|
||||
if offset == -1 {
|
||||
close(t.c.closed)
|
||||
close(t.c.events)
|
||||
return 0, io.EOF
|
||||
}
|
||||
msg := make([]byte, len(data)-offset)
|
||||
copy(msg, data[offset:])
|
||||
select {
|
||||
case <-t.c.closed:
|
||||
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
|
||||
case t.c.Inbound <- msg:
|
||||
return len(data) - offset, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (t *chTun) Flush() error { return nil }
|
||||
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
|
||||
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
|
||||
func (t *chTun) Events() chan tun.Event { return t.c.events }
|
||||
func (t *chTun) Close() error {
|
||||
t.Write(nil, -1)
|
||||
return nil
|
||||
}
|
||||
|
||||
func assertNil(t *testing.T, err error) {
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/binary"
|
||||
"sync"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
type IndexTableEntry struct {
|
||||
@@ -25,7 +25,8 @@ type IndexTable struct {
|
||||
func randUint32() (uint32, error) {
|
||||
var integer [4]byte
|
||||
_, err := rand.Read(integer[:])
|
||||
return *(*uint32)(unsafe.Pointer(&integer[0])), err
|
||||
// Arbitrary endianness; both are intrinsified by the Go compiler.
|
||||
return binary.LittleEndian.Uint32(integer[:]), err
|
||||
}
|
||||
|
||||
func (table *IndexTable) Init() {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@@ -8,7 +8,9 @@ package device
|
||||
import (
|
||||
"crypto/cipher"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.zx2c4.com/wireguard/replay"
|
||||
)
|
||||
@@ -24,7 +26,7 @@ type Keypair struct {
|
||||
sendNonce uint64
|
||||
send cipher.AEAD
|
||||
receive cipher.AEAD
|
||||
replayFilter replay.ReplayFilter
|
||||
replayFilter replay.Filter
|
||||
isInitiator bool
|
||||
created time.Time
|
||||
localIndex uint32
|
||||
@@ -38,6 +40,14 @@ type Keypairs struct {
|
||||
next *Keypair
|
||||
}
|
||||
|
||||
func (kp *Keypairs) storeNext(next *Keypair) {
|
||||
atomic.StorePointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next)), (unsafe.Pointer)(next))
|
||||
}
|
||||
|
||||
func (kp *Keypairs) loadNext() *Keypair {
|
||||
return (*Keypair)(atomic.LoadPointer((*unsafe.Pointer)((unsafe.Pointer)(&kp.next))))
|
||||
}
|
||||
|
||||
func (kp *Keypairs) Current() *Keypair {
|
||||
kp.RLock()
|
||||
defer kp.RUnlock()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
16
device/mobilequirks.go
Normal file
16
device/mobilequirks.go
Normal file
@@ -0,0 +1,16 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
|
||||
device.peers.RLock()
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
peer.disableRoaming = peer.endpoint != nil
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -1,29 +1,51 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/blake2s"
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/crypto/poly1305"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tai64n"
|
||||
)
|
||||
|
||||
type handshakeState int
|
||||
|
||||
// TODO(crawshaw): add commentary describing each state and the transitions
|
||||
const (
|
||||
HandshakeZeroed = iota
|
||||
HandshakeInitiationCreated
|
||||
HandshakeInitiationConsumed
|
||||
HandshakeResponseCreated
|
||||
HandshakeResponseConsumed
|
||||
handshakeZeroed = handshakeState(iota)
|
||||
handshakeInitiationCreated
|
||||
handshakeInitiationConsumed
|
||||
handshakeResponseCreated
|
||||
handshakeResponseConsumed
|
||||
)
|
||||
|
||||
func (hs handshakeState) String() string {
|
||||
switch hs {
|
||||
case handshakeZeroed:
|
||||
return "handshakeZeroed"
|
||||
case handshakeInitiationCreated:
|
||||
return "handshakeInitiationCreated"
|
||||
case handshakeInitiationConsumed:
|
||||
return "handshakeInitiationConsumed"
|
||||
case handshakeResponseCreated:
|
||||
return "handshakeResponseCreated"
|
||||
case handshakeResponseConsumed:
|
||||
return "handshakeResponseConsumed"
|
||||
default:
|
||||
return fmt.Sprintf("Handshake(UNKNOWN:%d)", int(hs))
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
|
||||
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
|
||||
@@ -95,7 +117,7 @@ type MessageCookieReply struct {
|
||||
}
|
||||
|
||||
type Handshake struct {
|
||||
state int
|
||||
state handshakeState
|
||||
mutex sync.RWMutex
|
||||
hash [blake2s.Size]byte // hash value
|
||||
chainKey [blake2s.Size]byte // chain key
|
||||
@@ -135,7 +157,7 @@ func (h *Handshake) Clear() {
|
||||
setZero(h.chainKey[:])
|
||||
setZero(h.hash[:])
|
||||
h.localIndex = 0
|
||||
h.state = HandshakeZeroed
|
||||
h.state = handshakeZeroed
|
||||
}
|
||||
|
||||
func (h *Handshake) mixHash(data []byte) {
|
||||
@@ -154,6 +176,7 @@ func init() {
|
||||
}
|
||||
|
||||
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
|
||||
var errZeroECDHResult = errors.New("ECDH returned all zeros")
|
||||
|
||||
device.staticIdentity.RLock()
|
||||
defer device.staticIdentity.RUnlock()
|
||||
@@ -162,12 +185,7 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
||||
handshake.mutex.Lock()
|
||||
defer handshake.mutex.Unlock()
|
||||
|
||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||
return nil, errors.New("static shared secret is zero")
|
||||
}
|
||||
|
||||
// create ephemeral key
|
||||
|
||||
var err error
|
||||
handshake.hash = InitialHash
|
||||
handshake.chainKey = InitialChainKey
|
||||
@@ -176,31 +194,22 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// assign index
|
||||
|
||||
device.indexTable.Delete(handshake.localIndex)
|
||||
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
handshake.mixHash(handshake.remoteStatic[:])
|
||||
|
||||
msg := MessageInitiation{
|
||||
Type: MessageInitiationType,
|
||||
Ephemeral: handshake.localEphemeral.publicKey(),
|
||||
Sender: handshake.localIndex,
|
||||
}
|
||||
|
||||
handshake.mixKey(msg.Ephemeral[:])
|
||||
handshake.mixHash(msg.Ephemeral[:])
|
||||
|
||||
// encrypt static key
|
||||
|
||||
func() {
|
||||
var key [chacha20poly1305.KeySize]byte
|
||||
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
|
||||
if isZero(ss[:]) {
|
||||
return nil, errZeroECDHResult
|
||||
}
|
||||
var key [chacha20poly1305.KeySize]byte
|
||||
KDF2(
|
||||
&handshake.chainKey,
|
||||
&key,
|
||||
@@ -209,26 +218,32 @@ func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, e
|
||||
)
|
||||
aead, _ := chacha20poly1305.New(key[:])
|
||||
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
|
||||
}()
|
||||
handshake.mixHash(msg.Static[:])
|
||||
|
||||
// encrypt timestamp
|
||||
|
||||
timestamp := tai64n.Now()
|
||||
func() {
|
||||
var key [chacha20poly1305.KeySize]byte
|
||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||
return nil, errZeroECDHResult
|
||||
}
|
||||
KDF2(
|
||||
&handshake.chainKey,
|
||||
&key,
|
||||
handshake.chainKey[:],
|
||||
handshake.precomputedStaticStatic[:],
|
||||
)
|
||||
aead, _ := chacha20poly1305.New(key[:])
|
||||
timestamp := tai64n.Now()
|
||||
aead, _ = chacha20poly1305.New(key[:])
|
||||
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
|
||||
}()
|
||||
|
||||
// assign index
|
||||
device.indexTable.Delete(handshake.localIndex)
|
||||
msg.Sender, err = device.indexTable.NewIndexForHandshake(peer, handshake)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
handshake.localIndex = msg.Sender
|
||||
|
||||
handshake.mixHash(msg.Timestamp[:])
|
||||
handshake.state = HandshakeInitiationCreated
|
||||
handshake.state = handshakeInitiationCreated
|
||||
return &msg, nil
|
||||
}
|
||||
|
||||
@@ -250,16 +265,16 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
|
||||
|
||||
// decrypt static key
|
||||
|
||||
var err error
|
||||
var peerPK NoisePublicKey
|
||||
func() {
|
||||
var key [chacha20poly1305.KeySize]byte
|
||||
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
|
||||
if isZero(ss[:]) {
|
||||
return nil
|
||||
}
|
||||
KDF2(&chainKey, &key, chainKey[:], ss[:])
|
||||
aead, _ := chacha20poly1305.New(key[:])
|
||||
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
|
||||
}()
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
@@ -273,23 +288,24 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||
}
|
||||
|
||||
handshake := &peer.handshake
|
||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// verify identity
|
||||
|
||||
var timestamp tai64n.Timestamp
|
||||
var key [chacha20poly1305.KeySize]byte
|
||||
|
||||
handshake.mutex.RLock()
|
||||
|
||||
if isZero(handshake.precomputedStaticStatic[:]) {
|
||||
handshake.mutex.RUnlock()
|
||||
return nil
|
||||
}
|
||||
KDF2(
|
||||
&chainKey,
|
||||
&key,
|
||||
chainKey[:],
|
||||
handshake.precomputedStaticStatic[:],
|
||||
)
|
||||
aead, _ := chacha20poly1305.New(key[:])
|
||||
aead, _ = chacha20poly1305.New(key[:])
|
||||
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
|
||||
if err != nil {
|
||||
handshake.mutex.RUnlock()
|
||||
@@ -299,11 +315,15 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||
|
||||
// protect against replay & flood
|
||||
|
||||
var ok bool
|
||||
ok = timestamp.After(handshake.lastTimestamp)
|
||||
ok = ok && time.Since(handshake.lastInitiationConsumption) > HandshakeInitationRate
|
||||
replay := !timestamp.After(handshake.lastTimestamp)
|
||||
flood := time.Since(handshake.lastInitiationConsumption) <= HandshakeInitationRate
|
||||
handshake.mutex.RUnlock()
|
||||
if !ok {
|
||||
if replay {
|
||||
device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake replay @ %v\n", peer, timestamp)
|
||||
return nil
|
||||
}
|
||||
if flood {
|
||||
device.log.Debug.Printf("%v - ConsumeMessageInitiation: handshake flood\n", peer)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -322,7 +342,7 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
|
||||
if now.After(handshake.lastInitiationConsumption) {
|
||||
handshake.lastInitiationConsumption = now
|
||||
}
|
||||
handshake.state = HandshakeInitiationConsumed
|
||||
handshake.state = handshakeInitiationConsumed
|
||||
|
||||
handshake.mutex.Unlock()
|
||||
|
||||
@@ -337,7 +357,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||
handshake.mutex.Lock()
|
||||
defer handshake.mutex.Unlock()
|
||||
|
||||
if handshake.state != HandshakeInitiationConsumed {
|
||||
if handshake.state != handshakeInitiationConsumed {
|
||||
return nil, errors.New("handshake initiation must be consumed first")
|
||||
}
|
||||
|
||||
@@ -393,7 +413,7 @@ func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error
|
||||
handshake.mixHash(msg.Empty[:])
|
||||
}()
|
||||
|
||||
handshake.state = HandshakeResponseCreated
|
||||
handshake.state = handshakeResponseCreated
|
||||
|
||||
return &msg, nil
|
||||
}
|
||||
@@ -423,7 +443,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||
handshake.mutex.RLock()
|
||||
defer handshake.mutex.RUnlock()
|
||||
|
||||
if handshake.state != HandshakeInitiationCreated {
|
||||
if handshake.state != handshakeInitiationCreated {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -484,7 +504,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
|
||||
handshake.hash = hash
|
||||
handshake.chainKey = chainKey
|
||||
handshake.remoteIndex = msg.Sender
|
||||
handshake.state = HandshakeResponseConsumed
|
||||
handshake.state = handshakeResponseConsumed
|
||||
|
||||
handshake.mutex.Unlock()
|
||||
|
||||
@@ -509,7 +529,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||
var sendKey [chacha20poly1305.KeySize]byte
|
||||
var recvKey [chacha20poly1305.KeySize]byte
|
||||
|
||||
if handshake.state == HandshakeResponseConsumed {
|
||||
if handshake.state == handshakeResponseConsumed {
|
||||
KDF2(
|
||||
&sendKey,
|
||||
&recvKey,
|
||||
@@ -517,7 +537,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||
nil,
|
||||
)
|
||||
isInitiator = true
|
||||
} else if handshake.state == HandshakeResponseCreated {
|
||||
} else if handshake.state == handshakeResponseCreated {
|
||||
KDF2(
|
||||
&recvKey,
|
||||
&sendKey,
|
||||
@@ -526,7 +546,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||
)
|
||||
isInitiator = false
|
||||
} else {
|
||||
return errors.New("invalid state for keypair derivation")
|
||||
return fmt.Errorf("invalid state for keypair derivation: %v", handshake.state)
|
||||
}
|
||||
|
||||
// zero handshake
|
||||
@@ -534,7 +554,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||
setZero(handshake.chainKey[:])
|
||||
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
|
||||
setZero(handshake.localEphemeral[:])
|
||||
peer.handshake.state = HandshakeZeroed
|
||||
peer.handshake.state = handshakeZeroed
|
||||
|
||||
// create AEAD instances
|
||||
|
||||
@@ -547,7 +567,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||
|
||||
keypair.created = time.Now()
|
||||
keypair.sendNonce = 0
|
||||
keypair.replayFilter.Init()
|
||||
keypair.replayFilter.Reset()
|
||||
keypair.isInitiator = isInitiator
|
||||
keypair.localIndex = peer.handshake.localIndex
|
||||
keypair.remoteIndex = peer.handshake.remoteIndex
|
||||
@@ -564,12 +584,12 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||
defer keypairs.Unlock()
|
||||
|
||||
previous := keypairs.previous
|
||||
next := keypairs.next
|
||||
next := keypairs.loadNext()
|
||||
current := keypairs.current
|
||||
|
||||
if isInitiator {
|
||||
if next != nil {
|
||||
keypairs.next = nil
|
||||
keypairs.storeNext(nil)
|
||||
keypairs.previous = next
|
||||
device.DeleteKeypair(current)
|
||||
} else {
|
||||
@@ -578,7 +598,7 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||
device.DeleteKeypair(previous)
|
||||
keypairs.current = keypair
|
||||
} else {
|
||||
keypairs.next = keypair
|
||||
keypairs.storeNext(keypair)
|
||||
device.DeleteKeypair(next)
|
||||
keypairs.previous = nil
|
||||
device.DeleteKeypair(previous)
|
||||
@@ -589,18 +609,19 @@ func (peer *Peer) BeginSymmetricSession() error {
|
||||
|
||||
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
|
||||
keypairs := &peer.keypairs
|
||||
if keypairs.next != receivedKeypair {
|
||||
|
||||
if keypairs.loadNext() != receivedKeypair {
|
||||
return false
|
||||
}
|
||||
keypairs.Lock()
|
||||
defer keypairs.Unlock()
|
||||
if keypairs.next != receivedKeypair {
|
||||
if keypairs.loadNext() != receivedKeypair {
|
||||
return false
|
||||
}
|
||||
old := keypairs.previous
|
||||
keypairs.previous = keypairs.current
|
||||
peer.device.DeleteKeypair(old)
|
||||
keypairs.current = keypairs.next
|
||||
keypairs.next = nil
|
||||
keypairs.current = keypairs.loadNext()
|
||||
keypairs.storeNext(nil)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@@ -52,6 +52,15 @@ func (key *NoisePrivateKey) FromHex(src string) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (key *NoisePrivateKey) FromMaybeZeroHex(src string) (err error) {
|
||||
err = loadExactHex(key[:], src)
|
||||
if key.IsZero() {
|
||||
return
|
||||
}
|
||||
key.clamp()
|
||||
return
|
||||
}
|
||||
|
||||
func (key NoisePrivateKey) ToHex() string {
|
||||
return hex.EncodeToString(key[:])
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@@ -113,7 +113,7 @@ func TestNoiseHandshake(t *testing.T) {
|
||||
t.Fatal("failed to derive keypair for peer 2", err)
|
||||
}
|
||||
|
||||
key1 := peer1.keypairs.next
|
||||
key1 := peer1.keypairs.loadNext()
|
||||
key2 := peer2.keypairs.current
|
||||
|
||||
// encrypting / decryption test
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@@ -12,6 +12,8 @@ import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -24,10 +26,15 @@ type Peer struct {
|
||||
keypairs Keypairs
|
||||
handshake Handshake
|
||||
device *Device
|
||||
endpoint Endpoint
|
||||
endpoint conn.Endpoint
|
||||
persistentKeepaliveInterval uint16
|
||||
disableRoaming bool
|
||||
|
||||
// This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
|
||||
// These fields are accessed with atomic operations, which must be
|
||||
// 64-bit aligned even on 32-bit platforms. Go guarantees that an
|
||||
// allocated struct will be 64-bit aligned. So we place
|
||||
// atomically-accessed fields up front, so that they can share in
|
||||
// this alignment before smaller fields throw it off.
|
||||
stats struct {
|
||||
txBytes uint64 // bytes send to peer (endpoint)
|
||||
rxBytes uint64 // bytes received from peer
|
||||
@@ -51,6 +58,7 @@ type Peer struct {
|
||||
}
|
||||
|
||||
queue struct {
|
||||
sync.RWMutex
|
||||
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
|
||||
outbound chan *QueueOutboundElement // sequential ordering of work
|
||||
inbound chan *QueueInboundElement // sequential ordering of work
|
||||
@@ -108,7 +116,6 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||
handshake := &peer.handshake
|
||||
handshake.mutex.Lock()
|
||||
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
|
||||
ssIsZero := isZero(handshake.precomputedStaticStatic[:])
|
||||
handshake.remoteStatic = pk
|
||||
handshake.mutex.Unlock()
|
||||
|
||||
@@ -116,13 +123,9 @@ func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
|
||||
|
||||
peer.endpoint = nil
|
||||
|
||||
// conditionally add
|
||||
// add
|
||||
|
||||
if !ssIsZero {
|
||||
device.peers.keyMap[pk] = peer
|
||||
} else {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// start peer
|
||||
|
||||
@@ -193,10 +196,11 @@ func (peer *Peer) Start() {
|
||||
peer.routines.stopping.Add(PeerRoutineNumber)
|
||||
|
||||
// prepare queues
|
||||
|
||||
peer.queue.Lock()
|
||||
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
|
||||
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
|
||||
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
|
||||
peer.queue.Unlock()
|
||||
|
||||
peer.timersInit()
|
||||
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
|
||||
@@ -222,10 +226,10 @@ func (peer *Peer) ZeroAndFlushAll() {
|
||||
keypairs.Lock()
|
||||
device.DeleteKeypair(keypairs.previous)
|
||||
device.DeleteKeypair(keypairs.current)
|
||||
device.DeleteKeypair(keypairs.next)
|
||||
device.DeleteKeypair(keypairs.loadNext())
|
||||
keypairs.previous = nil
|
||||
keypairs.current = nil
|
||||
keypairs.next = nil
|
||||
keypairs.storeNext(nil)
|
||||
keypairs.Unlock()
|
||||
|
||||
// clear handshake state
|
||||
@@ -253,7 +257,7 @@ func (peer *Peer) ExpireCurrentKeypairs() {
|
||||
keypairs.current.sendNonce = RejectAfterMessages
|
||||
}
|
||||
if keypairs.next != nil {
|
||||
keypairs.next.sendNonce = RejectAfterMessages
|
||||
keypairs.loadNext().sendNonce = RejectAfterMessages
|
||||
}
|
||||
keypairs.Unlock()
|
||||
}
|
||||
@@ -282,17 +286,17 @@ func (peer *Peer) Stop() {
|
||||
|
||||
// close queues
|
||||
|
||||
peer.queue.Lock()
|
||||
close(peer.queue.nonce)
|
||||
close(peer.queue.outbound)
|
||||
close(peer.queue.inbound)
|
||||
peer.queue.Unlock()
|
||||
|
||||
peer.ZeroAndFlushAll()
|
||||
}
|
||||
|
||||
var RoamingDisabled bool
|
||||
|
||||
func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
|
||||
if RoamingDisabled {
|
||||
func (peer *Peer) SetEndpointFromPacket(endpoint conn.Endpoint) {
|
||||
if peer.disableRoaming {
|
||||
return
|
||||
}
|
||||
peer.Lock()
|
||||
|
||||
43
device/peer_test.go
Normal file
43
device/peer_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func checkAlignment(t *testing.T, name string, offset uintptr) {
|
||||
t.Helper()
|
||||
if offset%8 != 0 {
|
||||
t.Errorf("offset of %q within struct is %d bytes, which does not align to 64-bit word boundaries (missing %d bytes). Atomic operations will crash on 32-bit systems.", name, offset, 8-(offset%8))
|
||||
}
|
||||
}
|
||||
|
||||
// TestPeerAlignment checks that atomically-accessed fields are
|
||||
// aligned to 64-bit boundaries, as required by the atomic package.
|
||||
//
|
||||
// Unfortunately, violating this rule on 32-bit platforms results in a
|
||||
// hard segfault at runtime.
|
||||
func TestPeerAlignment(t *testing.T) {
|
||||
var p Peer
|
||||
|
||||
typ := reflect.TypeOf(p)
|
||||
t.Logf("Peer type size: %d, with fields:", typ.Size())
|
||||
for i := 0; i < typ.NumField(); i++ {
|
||||
field := typ.Field(i)
|
||||
t.Logf("\t%30s\toffset=%3v\t(type size=%3d, align=%d)",
|
||||
field.Name,
|
||||
field.Offset,
|
||||
field.Type.Size(),
|
||||
field.Type.Align(),
|
||||
)
|
||||
}
|
||||
|
||||
checkAlignment(t, "Peer.stats", unsafe.Offsetof(p.stats))
|
||||
checkAlignment(t, "Peer.isRunning", unsafe.Offsetof(p.isRunning))
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@@ -17,12 +17,13 @@ import (
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
type QueueHandshakeElement struct {
|
||||
msgType uint32
|
||||
packet []byte
|
||||
endpoint Endpoint
|
||||
endpoint conn.Endpoint
|
||||
buffer *[MaxMessageSize]byte
|
||||
}
|
||||
|
||||
@@ -33,7 +34,7 @@ type QueueInboundElement struct {
|
||||
packet []byte
|
||||
counter uint64
|
||||
keypair *Keypair
|
||||
endpoint Endpoint
|
||||
endpoint conn.Endpoint
|
||||
}
|
||||
|
||||
func (elem *QueueInboundElement) Drop() {
|
||||
@@ -90,7 +91,7 @@ func (peer *Peer) keepKeyFreshReceiving() {
|
||||
* Every time the bind is updated a new routine is started for
|
||||
* IPv4 and IPv6 (separately)
|
||||
*/
|
||||
func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
|
||||
func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
|
||||
|
||||
logDebug := device.log.Debug
|
||||
defer func() {
|
||||
@@ -108,7 +109,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
|
||||
var (
|
||||
err error
|
||||
size int
|
||||
endpoint Endpoint
|
||||
endpoint conn.Endpoint
|
||||
)
|
||||
|
||||
for {
|
||||
@@ -183,11 +184,13 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
|
||||
|
||||
// add to decryption queues
|
||||
|
||||
peer.queue.RLock()
|
||||
if peer.isRunning.Get() {
|
||||
if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
|
||||
buffer = device.GetMessageBuffer()
|
||||
}
|
||||
}
|
||||
peer.queue.RUnlock()
|
||||
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
@@ -107,6 +107,8 @@ func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement,
|
||||
/* Queues a keepalive if no packets are queued for peer
|
||||
*/
|
||||
func (peer *Peer) SendKeepalive() bool {
|
||||
peer.queue.RLock()
|
||||
defer peer.queue.RUnlock()
|
||||
if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() {
|
||||
return false
|
||||
}
|
||||
@@ -310,6 +312,7 @@ func (device *Device) RoutineReadFromTUN() {
|
||||
|
||||
// insert into nonce/pre-handshake queue
|
||||
|
||||
peer.queue.RLock()
|
||||
if peer.isRunning.Get() {
|
||||
if peer.queue.packetInNonceQueueIsAwaitingKey.Get() {
|
||||
peer.SendHandshakeInitiation(false)
|
||||
@@ -317,6 +320,7 @@ func (device *Device) RoutineReadFromTUN() {
|
||||
addToNonceQueue(peer.queue.nonce, elem, device)
|
||||
elem = nil
|
||||
}
|
||||
peer.queue.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -448,6 +452,21 @@ func (peer *Peer) RoutineNonce() {
|
||||
}
|
||||
}
|
||||
|
||||
func calculatePaddingSize(packetSize, mtu int) int {
|
||||
lastUnit := packetSize
|
||||
if mtu == 0 {
|
||||
return ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)) - lastUnit
|
||||
}
|
||||
if lastUnit > mtu {
|
||||
lastUnit %= mtu
|
||||
}
|
||||
paddedSize := ((lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1))
|
||||
if paddedSize > mtu {
|
||||
paddedSize = mtu
|
||||
}
|
||||
return paddedSize - lastUnit
|
||||
}
|
||||
|
||||
/* Encrypts the elements in the queue
|
||||
* and marks them for sequential consumption (by releasing the mutex)
|
||||
*
|
||||
@@ -514,13 +533,8 @@ func (device *Device) RoutineEncryption() {
|
||||
|
||||
// pad content to multiple of 16
|
||||
|
||||
mtu := int(atomic.LoadInt32(&device.tun.mtu))
|
||||
lastUnit := len(elem.packet) % mtu
|
||||
paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
|
||||
if paddedSize > mtu {
|
||||
paddedSize = mtu
|
||||
}
|
||||
for i := len(elem.packet); i < paddedSize; i++ {
|
||||
paddingSize := calculatePaddingSize(len(elem.packet), int(atomic.LoadInt32(&device.tun.mtu)))
|
||||
for i := 0; i < paddingSize; i++ {
|
||||
elem.packet = append(elem.packet, 0)
|
||||
}
|
||||
|
||||
|
||||
12
device/sticky_default.go
Normal file
12
device/sticky_default.go
Normal file
@@ -0,0 +1,12 @@
|
||||
// +build !linux android
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
)
|
||||
|
||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||
return nil, nil
|
||||
}
|
||||
217
device/sticky_linux.go
Normal file
217
device/sticky_linux.go
Normal file
@@ -0,0 +1,217 @@
|
||||
// +build !android
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* This implements userspace semantics of "sticky sockets", modeled after
|
||||
* WireGuard's kernelspace implementation. This is more or less a straight port
|
||||
* of the sticky-sockets.c example code:
|
||||
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
|
||||
*
|
||||
* Currently there is no way to achieve this within the net package:
|
||||
* See e.g. https://github.com/golang/go/issues/17930
|
||||
* So this code is remains platform dependent.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
)
|
||||
|
||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||
netlinkSock, err := createNetlinkRouteSocket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
netlinkCancel, err := rwcancel.NewRWCancel(netlinkSock)
|
||||
if err != nil {
|
||||
unix.Close(netlinkSock)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go device.routineRouteListener(bind, netlinkSock, netlinkCancel)
|
||||
|
||||
return netlinkCancel, nil
|
||||
}
|
||||
|
||||
func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netlinkCancel *rwcancel.RWCancel) {
|
||||
type peerEndpointPtr struct {
|
||||
peer *Peer
|
||||
endpoint *conn.Endpoint
|
||||
}
|
||||
var reqPeer map[uint32]peerEndpointPtr
|
||||
var reqPeerLock sync.Mutex
|
||||
|
||||
defer unix.Close(netlinkSock)
|
||||
|
||||
for msg := make([]byte, 1<<16); ; {
|
||||
var err error
|
||||
var msgn int
|
||||
for {
|
||||
msgn, _, _, _, err = unix.Recvmsg(netlinkSock, msg[:], nil, 0)
|
||||
if err == nil || !rwcancel.RetryAfterError(err) {
|
||||
break
|
||||
}
|
||||
if !netlinkCancel.ReadyRead() {
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||
|
||||
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||
|
||||
if uint(hdr.Len) > uint(len(remain)) {
|
||||
break
|
||||
}
|
||||
|
||||
switch hdr.Type {
|
||||
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
|
||||
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
|
||||
if uint(len(remain)) < uint(hdr.Len) {
|
||||
break
|
||||
}
|
||||
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
|
||||
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
|
||||
for {
|
||||
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
|
||||
break
|
||||
}
|
||||
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
|
||||
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
|
||||
break
|
||||
}
|
||||
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
|
||||
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
|
||||
reqPeerLock.Lock()
|
||||
if reqPeer == nil {
|
||||
reqPeerLock.Unlock()
|
||||
break
|
||||
}
|
||||
pePtr, ok := reqPeer[hdr.Seq]
|
||||
reqPeerLock.Unlock()
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
pePtr.peer.Lock()
|
||||
if &pePtr.peer.endpoint != pePtr.endpoint {
|
||||
pePtr.peer.Unlock()
|
||||
break
|
||||
}
|
||||
if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx {
|
||||
pePtr.peer.Unlock()
|
||||
break
|
||||
}
|
||||
pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc()
|
||||
pePtr.peer.Unlock()
|
||||
}
|
||||
attr = attr[attrhdr.Len:]
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
reqPeerLock.Lock()
|
||||
reqPeer = make(map[uint32]peerEndpointPtr)
|
||||
reqPeerLock.Unlock()
|
||||
go func() {
|
||||
device.peers.RLock()
|
||||
i := uint32(1)
|
||||
for _, peer := range device.peers.keyMap {
|
||||
peer.RLock()
|
||||
if peer.endpoint == nil {
|
||||
peer.RUnlock()
|
||||
continue
|
||||
}
|
||||
nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint)
|
||||
if nativeEP == nil {
|
||||
peer.RUnlock()
|
||||
continue
|
||||
}
|
||||
if nativeEP.IsV6() || nativeEP.Src4().Ifindex == 0 {
|
||||
peer.RUnlock()
|
||||
break
|
||||
}
|
||||
nlmsg := struct {
|
||||
hdr unix.NlMsghdr
|
||||
msg unix.RtMsg
|
||||
dsthdr unix.RtAttr
|
||||
dst [4]byte
|
||||
srchdr unix.RtAttr
|
||||
src [4]byte
|
||||
markhdr unix.RtAttr
|
||||
mark uint32
|
||||
}{
|
||||
unix.NlMsghdr{
|
||||
Type: uint16(unix.RTM_GETROUTE),
|
||||
Flags: unix.NLM_F_REQUEST,
|
||||
Seq: i,
|
||||
},
|
||||
unix.RtMsg{
|
||||
Family: unix.AF_INET,
|
||||
Dst_len: 32,
|
||||
Src_len: 32,
|
||||
},
|
||||
unix.RtAttr{
|
||||
Len: 8,
|
||||
Type: unix.RTA_DST,
|
||||
},
|
||||
nativeEP.Dst4().Addr,
|
||||
unix.RtAttr{
|
||||
Len: 8,
|
||||
Type: unix.RTA_SRC,
|
||||
},
|
||||
nativeEP.Src4().Src,
|
||||
unix.RtAttr{
|
||||
Len: 8,
|
||||
Type: unix.RTA_MARK,
|
||||
},
|
||||
uint32(bind.LastMark()),
|
||||
}
|
||||
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
|
||||
reqPeerLock.Lock()
|
||||
reqPeer[i] = peerEndpointPtr{
|
||||
peer: peer,
|
||||
endpoint: &peer.endpoint,
|
||||
}
|
||||
reqPeerLock.Unlock()
|
||||
peer.RUnlock()
|
||||
i++
|
||||
_, err := netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
device.peers.RUnlock()
|
||||
}()
|
||||
}
|
||||
remain = remain[hdr.Len:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func createNetlinkRouteSocket() (int, error) {
|
||||
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
saddr := &unix.SockaddrNetlink{
|
||||
Family: unix.AF_NETLINK,
|
||||
Groups: unix.RTMGRP_IPV4_ROUTE,
|
||||
}
|
||||
err = unix.Bind(sock, saddr)
|
||||
if err != nil {
|
||||
unix.Close(sock)
|
||||
return -1, err
|
||||
}
|
||||
return sock, nil
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*
|
||||
* This is based heavily on timers.c from the kernel implementation.
|
||||
*/
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
@@ -1,12 +1,13 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package device
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
)
|
||||
|
||||
@@ -30,7 +32,7 @@ func (s IPCError) ErrorCode() int64 {
|
||||
return s.int64
|
||||
}
|
||||
|
||||
func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
|
||||
func (device *Device) IpcGetOperation(socket *bufio.Writer) error {
|
||||
lines := make([]string, 0, 100)
|
||||
send := func(line string) {
|
||||
lines = append(lines, line)
|
||||
@@ -105,7 +107,7 @@ func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
||||
func (device *Device) IpcSetOperation(socket *bufio.Reader) error {
|
||||
scanner := bufio.NewScanner(socket)
|
||||
logError := device.log.Error
|
||||
logDebug := device.log.Debug
|
||||
@@ -138,7 +140,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
||||
switch key {
|
||||
case "private_key":
|
||||
var sk NoisePrivateKey
|
||||
err := sk.FromHex(value)
|
||||
err := sk.FromMaybeZeroHex(value)
|
||||
if err != nil {
|
||||
logError.Println("Failed to set private_key:", err)
|
||||
return &IPCError{ipc.IpcErrorInvalid}
|
||||
@@ -306,7 +308,7 @@ func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
|
||||
err := func() error {
|
||||
peer.Lock()
|
||||
defer peer.Unlock()
|
||||
endpoint, err := CreateEndpoint(value)
|
||||
endpoint, err := conn.CreateEndpoint(value)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -420,10 +422,20 @@ func (device *Device) IpcHandle(socket net.Conn) {
|
||||
|
||||
switch op {
|
||||
case "set=1\n":
|
||||
status = device.IpcSetOperation(buffered.Reader)
|
||||
err = device.IpcSetOperation(buffered.Reader)
|
||||
if err != nil && !errors.As(err, &status) {
|
||||
// should never happen
|
||||
device.log.Error.Println("Invalid UAPI error:", err)
|
||||
status = &IPCError{1}
|
||||
}
|
||||
|
||||
case "get=1\n":
|
||||
status = device.IpcGetOperation(buffered.Writer)
|
||||
err = device.IpcGetOperation(buffered.Writer)
|
||||
if err != nil && !errors.As(err, &status) {
|
||||
// should never happen
|
||||
device.log.Error.Println("Invalid UAPI error:", err)
|
||||
status = &IPCError{1}
|
||||
}
|
||||
|
||||
default:
|
||||
device.log.Error.Println("Invalid UAPI operation:", op)
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
package device
|
||||
|
||||
const WireGuardGoVersion = "0.0.20200121"
|
||||
const WireGuardGoVersion = "0.0.20201118"
|
||||
|
||||
9
go.mod
9
go.mod
@@ -1,10 +1,9 @@
|
||||
module golang.zx2c4.com/wireguard
|
||||
|
||||
go 1.12
|
||||
go 1.13
|
||||
|
||||
require (
|
||||
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc
|
||||
golang.org/x/net v0.0.0-20191003171128-d98b1b443823
|
||||
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c
|
||||
golang.org/x/text v0.3.2
|
||||
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9
|
||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b
|
||||
golang.org/x/sys v0.0.0-20201117222635-ba5294a509c7
|
||||
)
|
||||
|
||||
19
go.sum
19
go.sum
@@ -1,14 +1,17 @@
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc h1:c0o/qxkaO2LF5t6fQrT4b5hzyggAkLLlCUjqfRxd8Q4=
|
||||
golang.org/x/crypto v0.0.0-20191002192127-34f69633bfdc/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9 h1:phUcVbl53swtrUN8kQEXFhUxPlIlWyBfKmidCu7P95o=
|
||||
golang.org/x/crypto v0.0.0-20201117144127-c1f2f97bffc9/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20191003171128-d98b1b443823 h1:Ypyv6BNJh07T1pUSrehkLemqPKXhus2MkfktJ91kRh4=
|
||||
golang.org/x/net v0.0.0-20191003171128-d98b1b443823/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b h1:uwuIcX0g4Yl1NC5XAz37xsr2lTtcqevgzYNVt49waME=
|
||||
golang.org/x/net v0.0.0-20201110031124-69a78807bb2b/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c h1:6Zx7DRlKXf79yfxuQ/7GqV3w2y7aDsk6bGg0MzF5RVU=
|
||||
golang.org/x/sys v0.0.0-20191003212358-c178f38b412c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201117222635-ba5294a509c7 h1:s330+6z/Ko3J0o6rvOcwXe5nzs7UT9tLKHoOXYn6uE0=
|
||||
golang.org/x/sys v0.0.0-20201117222635-ba5294a509c7/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
|
||||
@@ -2,32 +2,20 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
var socketDirectory = "/var/run/wireguard"
|
||||
|
||||
const (
|
||||
IpcErrorIO = -int64(unix.EIO)
|
||||
IpcErrorProtocol = -int64(unix.EPROTO)
|
||||
IpcErrorInvalid = -int64(unix.EINVAL)
|
||||
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
|
||||
socketName = "%s.sock"
|
||||
)
|
||||
|
||||
type UAPIListener struct {
|
||||
listener net.Listener // unix socket listener
|
||||
connNew chan net.Conn
|
||||
@@ -84,10 +72,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||
unixListener.SetUnlinkOnClose(true)
|
||||
}
|
||||
|
||||
socketPath := path.Join(
|
||||
socketDirectory,
|
||||
fmt.Sprintf(socketName, name),
|
||||
)
|
||||
socketPath := sockPath(name)
|
||||
|
||||
// watch for deletion of socket
|
||||
|
||||
@@ -146,58 +131,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||
|
||||
return uapi, nil
|
||||
}
|
||||
|
||||
func UAPIOpen(name string) (*os.File, error) {
|
||||
|
||||
// check if path exist
|
||||
|
||||
err := os.MkdirAll(socketDirectory, 0755)
|
||||
if err != nil && !os.IsExist(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// open UNIX socket
|
||||
|
||||
socketPath := path.Join(
|
||||
socketDirectory,
|
||||
fmt.Sprintf(socketName, name),
|
||||
)
|
||||
|
||||
addr, err := net.ResolveUnixAddr("unix", socketPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oldUmask := unix.Umask(0077)
|
||||
listener, err := func() (*net.UnixListener, error) {
|
||||
|
||||
// initial connection attempt
|
||||
|
||||
listener, err := net.ListenUnix("unix", addr)
|
||||
if err == nil {
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// check if socket already active
|
||||
|
||||
_, err = net.Dial("unix", socketPath)
|
||||
if err == nil {
|
||||
return nil, errors.New("unix socket in use")
|
||||
}
|
||||
|
||||
// cleanup & attempt again
|
||||
|
||||
err = os.Remove(socketPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.ListenUnix("unix", addr)
|
||||
}()
|
||||
unix.Umask(oldUmask)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return listener.File()
|
||||
}
|
||||
|
||||
@@ -1,31 +1,18 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
)
|
||||
|
||||
var socketDirectory = "/var/run/wireguard"
|
||||
|
||||
const (
|
||||
IpcErrorIO = -int64(unix.EIO)
|
||||
IpcErrorProtocol = -int64(unix.EPROTO)
|
||||
IpcErrorInvalid = -int64(unix.EINVAL)
|
||||
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
|
||||
socketName = "%s.sock"
|
||||
)
|
||||
|
||||
type UAPIListener struct {
|
||||
listener net.Listener // unix socket listener
|
||||
connNew chan net.Conn
|
||||
@@ -84,10 +71,7 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||
|
||||
// watch for deletion of socket
|
||||
|
||||
socketPath := path.Join(
|
||||
socketDirectory,
|
||||
fmt.Sprintf(socketName, name),
|
||||
)
|
||||
socketPath := sockPath(name)
|
||||
|
||||
uapi.inotifyFd, err = unix.InotifyInit()
|
||||
if err != nil {
|
||||
@@ -143,58 +127,3 @@ func UAPIListen(name string, file *os.File) (net.Listener, error) {
|
||||
|
||||
return uapi, nil
|
||||
}
|
||||
|
||||
func UAPIOpen(name string) (*os.File, error) {
|
||||
|
||||
// check if path exist
|
||||
|
||||
err := os.MkdirAll(socketDirectory, 0755)
|
||||
if err != nil && !os.IsExist(err) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// open UNIX socket
|
||||
|
||||
socketPath := path.Join(
|
||||
socketDirectory,
|
||||
fmt.Sprintf(socketName, name),
|
||||
)
|
||||
|
||||
addr, err := net.ResolveUnixAddr("unix", socketPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oldUmask := unix.Umask(0077)
|
||||
listener, err := func() (*net.UnixListener, error) {
|
||||
|
||||
// initial connection attempt
|
||||
|
||||
listener, err := net.ListenUnix("unix", addr)
|
||||
if err == nil {
|
||||
return listener, nil
|
||||
}
|
||||
|
||||
// check if socket already active
|
||||
|
||||
_, err = net.Dial("unix", socketPath)
|
||||
if err == nil {
|
||||
return nil, errors.New("unix socket in use")
|
||||
}
|
||||
|
||||
// cleanup & attempt again
|
||||
|
||||
err = os.Remove(socketPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return net.ListenUnix("unix", addr)
|
||||
}()
|
||||
unix.Umask(oldUmask)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return listener.File()
|
||||
}
|
||||
|
||||
65
ipc/uapi_unix.go
Normal file
65
ipc/uapi_unix.go
Normal file
@@ -0,0 +1,65 @@
|
||||
// +build linux darwin freebsd openbsd
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
const (
|
||||
IpcErrorIO = -int64(unix.EIO)
|
||||
IpcErrorProtocol = -int64(unix.EPROTO)
|
||||
IpcErrorInvalid = -int64(unix.EINVAL)
|
||||
IpcErrorPortInUse = -int64(unix.EADDRINUSE)
|
||||
)
|
||||
|
||||
// socketDirectory is variable because it is modified by a linker
|
||||
// flag in wireguard-android.
|
||||
var socketDirectory = "/var/run/wireguard"
|
||||
|
||||
func sockPath(iface string) string {
|
||||
return fmt.Sprintf("%s/%s.sock", socketDirectory, iface)
|
||||
}
|
||||
|
||||
func UAPIOpen(name string) (*os.File, error) {
|
||||
if err := os.MkdirAll(socketDirectory, 0755); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
socketPath := sockPath(name)
|
||||
addr, err := net.ResolveUnixAddr("unix", socketPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
oldUmask := unix.Umask(0077)
|
||||
defer unix.Umask(oldUmask)
|
||||
|
||||
listener, err := net.ListenUnix("unix", addr)
|
||||
if err == nil {
|
||||
return listener.File()
|
||||
}
|
||||
|
||||
// Test socket, if not in use cleanup and try again.
|
||||
if _, err := net.Dial("unix", socketPath); err == nil {
|
||||
return nil, errors.New("unix socket in use")
|
||||
}
|
||||
if err := os.Remove(socketPath); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
listener, err = net.ListenUnix("unix", addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return listener.File()
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ipc
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2005 Microsoft
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
package winpipe
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2005 Microsoft
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package winpipe
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2005 Microsoft
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package winpipe
|
||||
|
||||
24
main.go
24
main.go
@@ -2,7 +2,7 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
@@ -41,18 +41,16 @@ func warning() {
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
|
||||
fmt.Fprintln(os.Stderr, "W G")
|
||||
fmt.Fprintln(os.Stderr, "W You are running this software on a Linux kernel, G")
|
||||
fmt.Fprintln(os.Stderr, "W which is probably unnecessary and misguided. This G")
|
||||
fmt.Fprintln(os.Stderr, "W is because the Linux kernel has built-in first G")
|
||||
fmt.Fprintln(os.Stderr, "W class support for WireGuard, and this support is G")
|
||||
fmt.Fprintln(os.Stderr, "W much more refined than this slower userspace G")
|
||||
fmt.Fprintln(os.Stderr, "W implementation. For more information on G")
|
||||
fmt.Fprintln(os.Stderr, "W installing the kernel module, please visit: G")
|
||||
fmt.Fprintln(os.Stderr, "W https://www.wireguard.com/install G")
|
||||
fmt.Fprintln(os.Stderr, "W G")
|
||||
fmt.Fprintln(os.Stderr, "WARNING WARNING WARNING WARNING WARNING WARNING WARNING")
|
||||
fmt.Fprintln(os.Stderr, "┌───────────────────────────────────────────────────┐")
|
||||
fmt.Fprintln(os.Stderr, "│ │")
|
||||
fmt.Fprintln(os.Stderr, "│ Running this software on Linux is unnecessary, │")
|
||||
fmt.Fprintln(os.Stderr, "│ because the Linux kernel has built-in first │")
|
||||
fmt.Fprintln(os.Stderr, "│ class support for WireGuard, which will be │")
|
||||
fmt.Fprintln(os.Stderr, "│ faster, slicker, and better integrated. For │")
|
||||
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
|
||||
fmt.Fprintln(os.Stderr, "│ please visit: <https://wireguard.com/install>. │")
|
||||
fmt.Fprintln(os.Stderr, "│ │")
|
||||
fmt.Fprintln(os.Stderr, "└───────────────────────────────────────────────────┘")
|
||||
}
|
||||
|
||||
func main() {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package main
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ratelimiter
|
||||
@@ -20,21 +20,23 @@ const (
|
||||
)
|
||||
|
||||
type RatelimiterEntry struct {
|
||||
sync.Mutex
|
||||
mu sync.Mutex
|
||||
lastTime time.Time
|
||||
tokens int64
|
||||
}
|
||||
|
||||
type Ratelimiter struct {
|
||||
sync.RWMutex
|
||||
stopReset chan struct{}
|
||||
mu sync.RWMutex
|
||||
timeNow func() time.Time
|
||||
|
||||
stopReset chan struct{} // send to reset, close to stop
|
||||
tableIPv4 map[[net.IPv4len]byte]*RatelimiterEntry
|
||||
tableIPv6 map[[net.IPv6len]byte]*RatelimiterEntry
|
||||
}
|
||||
|
||||
func (rate *Ratelimiter) Close() {
|
||||
rate.Lock()
|
||||
defer rate.Unlock()
|
||||
rate.mu.Lock()
|
||||
defer rate.mu.Unlock()
|
||||
|
||||
if rate.stopReset != nil {
|
||||
close(rate.stopReset)
|
||||
@@ -42,11 +44,14 @@ func (rate *Ratelimiter) Close() {
|
||||
}
|
||||
|
||||
func (rate *Ratelimiter) Init() {
|
||||
rate.Lock()
|
||||
defer rate.Unlock()
|
||||
rate.mu.Lock()
|
||||
defer rate.mu.Unlock()
|
||||
|
||||
if rate.timeNow == nil {
|
||||
rate.timeNow = time.Now
|
||||
}
|
||||
|
||||
// stop any ongoing garbage collection routine
|
||||
|
||||
if rate.stopReset != nil {
|
||||
close(rate.stopReset)
|
||||
}
|
||||
@@ -55,48 +60,50 @@ func (rate *Ratelimiter) Init() {
|
||||
rate.tableIPv4 = make(map[[net.IPv4len]byte]*RatelimiterEntry)
|
||||
rate.tableIPv6 = make(map[[net.IPv6len]byte]*RatelimiterEntry)
|
||||
|
||||
// start garbage collection routine
|
||||
stopReset := rate.stopReset // store in case Init is called again.
|
||||
|
||||
// Start garbage collection routine.
|
||||
go func() {
|
||||
ticker := time.NewTicker(time.Second)
|
||||
ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case _, ok := <-rate.stopReset:
|
||||
case _, ok := <-stopReset:
|
||||
ticker.Stop()
|
||||
if ok {
|
||||
ticker = time.NewTicker(time.Second)
|
||||
} else {
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
ticker = time.NewTicker(time.Second)
|
||||
case <-ticker.C:
|
||||
func() {
|
||||
rate.Lock()
|
||||
defer rate.Unlock()
|
||||
if rate.cleanup() {
|
||||
ticker.Stop()
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (rate *Ratelimiter) cleanup() (empty bool) {
|
||||
rate.mu.Lock()
|
||||
defer rate.mu.Unlock()
|
||||
|
||||
for key, entry := range rate.tableIPv4 {
|
||||
entry.Lock()
|
||||
if time.Since(entry.lastTime) > garbageCollectTime {
|
||||
entry.mu.Lock()
|
||||
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
|
||||
delete(rate.tableIPv4, key)
|
||||
}
|
||||
entry.Unlock()
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
|
||||
for key, entry := range rate.tableIPv6 {
|
||||
entry.Lock()
|
||||
if time.Since(entry.lastTime) > garbageCollectTime {
|
||||
entry.mu.Lock()
|
||||
if rate.timeNow().Sub(entry.lastTime) > garbageCollectTime {
|
||||
delete(rate.tableIPv6, key)
|
||||
}
|
||||
entry.Unlock()
|
||||
entry.mu.Unlock()
|
||||
}
|
||||
|
||||
if len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0 {
|
||||
ticker.Stop()
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}()
|
||||
return len(rate.tableIPv4) == 0 && len(rate.tableIPv6) == 0
|
||||
}
|
||||
|
||||
func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
||||
@@ -109,7 +116,7 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
||||
IPv4 := ip.To4()
|
||||
IPv6 := ip.To16()
|
||||
|
||||
rate.RLock()
|
||||
rate.mu.RLock()
|
||||
|
||||
if IPv4 != nil {
|
||||
copy(keyIPv4[:], IPv4)
|
||||
@@ -119,15 +126,15 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
||||
entry = rate.tableIPv6[keyIPv6]
|
||||
}
|
||||
|
||||
rate.RUnlock()
|
||||
rate.mu.RUnlock()
|
||||
|
||||
// make new entry if not found
|
||||
|
||||
if entry == nil {
|
||||
entry = new(RatelimiterEntry)
|
||||
entry.tokens = maxTokens - packetCost
|
||||
entry.lastTime = time.Now()
|
||||
rate.Lock()
|
||||
entry.lastTime = rate.timeNow()
|
||||
rate.mu.Lock()
|
||||
if IPv4 != nil {
|
||||
rate.tableIPv4[keyIPv4] = entry
|
||||
if len(rate.tableIPv4) == 1 && len(rate.tableIPv6) == 0 {
|
||||
@@ -139,14 +146,14 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
||||
rate.stopReset <- struct{}{}
|
||||
}
|
||||
}
|
||||
rate.Unlock()
|
||||
rate.mu.Unlock()
|
||||
return true
|
||||
}
|
||||
|
||||
// add tokens to entry
|
||||
|
||||
entry.Lock()
|
||||
now := time.Now()
|
||||
entry.mu.Lock()
|
||||
now := rate.timeNow()
|
||||
entry.tokens += now.Sub(entry.lastTime).Nanoseconds()
|
||||
entry.lastTime = now
|
||||
if entry.tokens > maxTokens {
|
||||
@@ -157,9 +164,9 @@ func (rate *Ratelimiter) Allow(ip net.IP) bool {
|
||||
|
||||
if entry.tokens > packetCost {
|
||||
entry.tokens -= packetCost
|
||||
entry.Unlock()
|
||||
entry.mu.Unlock()
|
||||
return true
|
||||
}
|
||||
entry.Unlock()
|
||||
entry.mu.Unlock()
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package ratelimiter
|
||||
@@ -11,22 +11,21 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
type RatelimiterResult struct {
|
||||
type result struct {
|
||||
allowed bool
|
||||
text string
|
||||
wait time.Duration
|
||||
}
|
||||
|
||||
func TestRatelimiter(t *testing.T) {
|
||||
var rate Ratelimiter
|
||||
var expectedResults []result
|
||||
|
||||
var ratelimiter Ratelimiter
|
||||
var expectedResults []RatelimiterResult
|
||||
|
||||
Nano := func(nano int64) time.Duration {
|
||||
nano := func(nano int64) time.Duration {
|
||||
return time.Nanosecond * time.Duration(nano)
|
||||
}
|
||||
|
||||
Add := func(res RatelimiterResult) {
|
||||
add := func(res result) {
|
||||
expectedResults = append(
|
||||
expectedResults,
|
||||
res,
|
||||
@@ -34,40 +33,40 @@ func TestRatelimiter(t *testing.T) {
|
||||
}
|
||||
|
||||
for i := 0; i < packetsBurstable; i++ {
|
||||
Add(RatelimiterResult{
|
||||
add(result{
|
||||
allowed: true,
|
||||
text: "initial burst",
|
||||
})
|
||||
}
|
||||
|
||||
Add(RatelimiterResult{
|
||||
add(result{
|
||||
allowed: false,
|
||||
text: "after burst",
|
||||
})
|
||||
|
||||
Add(RatelimiterResult{
|
||||
add(result{
|
||||
allowed: true,
|
||||
wait: Nano(time.Second.Nanoseconds() / packetsPerSecond),
|
||||
wait: nano(time.Second.Nanoseconds() / packetsPerSecond),
|
||||
text: "filling tokens for single packet",
|
||||
})
|
||||
|
||||
Add(RatelimiterResult{
|
||||
add(result{
|
||||
allowed: false,
|
||||
text: "not having refilled enough",
|
||||
})
|
||||
|
||||
Add(RatelimiterResult{
|
||||
add(result{
|
||||
allowed: true,
|
||||
wait: 2 * (Nano(time.Second.Nanoseconds() / packetsPerSecond)),
|
||||
wait: 2 * (nano(time.Second.Nanoseconds() / packetsPerSecond)),
|
||||
text: "filling tokens for two packet burst",
|
||||
})
|
||||
|
||||
Add(RatelimiterResult{
|
||||
add(result{
|
||||
allowed: true,
|
||||
text: "second packet in 2 packet burst",
|
||||
})
|
||||
|
||||
Add(RatelimiterResult{
|
||||
add(result{
|
||||
allowed: false,
|
||||
text: "packet following 2 packet burst",
|
||||
})
|
||||
@@ -89,14 +88,31 @@ func TestRatelimiter(t *testing.T) {
|
||||
net.ParseIP("3f0e:54a2:f5b4:cd19:a21d:58e1:3746:84c4"),
|
||||
}
|
||||
|
||||
ratelimiter.Init()
|
||||
now := time.Now()
|
||||
rate.timeNow = func() time.Time {
|
||||
return now
|
||||
}
|
||||
defer func() {
|
||||
// Lock to avoid data race with cleanup goroutine from Init.
|
||||
rate.mu.Lock()
|
||||
defer rate.mu.Unlock()
|
||||
|
||||
rate.timeNow = time.Now
|
||||
}()
|
||||
timeSleep := func(d time.Duration) {
|
||||
now = now.Add(d + 1)
|
||||
rate.cleanup()
|
||||
}
|
||||
|
||||
rate.Init()
|
||||
defer rate.Close()
|
||||
|
||||
for i, res := range expectedResults {
|
||||
time.Sleep(res.wait)
|
||||
timeSleep(res.wait)
|
||||
for _, ip := range ips {
|
||||
allowed := ratelimiter.Allow(ip)
|
||||
allowed := rate.Allow(ip)
|
||||
if allowed != res.allowed {
|
||||
t.Fatal("Test failed for", ip.String(), ", on:", i, "(", res.text, ")", "expected:", res.allowed, "got:", allowed)
|
||||
t.Fatalf("%d: %s: rate.Allow(%q)=%v, want %v", i, res.text, ip, allowed, res.allowed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
101
replay/replay.go
101
replay/replay.go
@@ -1,83 +1,62 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
// Package replay implements an efficient anti-replay algorithm as specified in RFC 6479.
|
||||
package replay
|
||||
|
||||
/* Implementation of RFC6479
|
||||
* https://tools.ietf.org/html/rfc6479
|
||||
*
|
||||
* The implementation is not safe for concurrent use!
|
||||
*/
|
||||
type block uint64
|
||||
|
||||
const (
|
||||
// See: https://golang.org/src/math/big/arith.go
|
||||
_Wordm = ^uintptr(0)
|
||||
_WordLogSize = _Wordm>>8&1 + _Wordm>>16&1 + _Wordm>>32&1
|
||||
_WordSize = 1 << _WordLogSize
|
||||
blockBitLog = 6 // 1<<6 == 64 bits
|
||||
blockBits = 1 << blockBitLog // must be power of 2
|
||||
ringBlocks = 1 << 7 // must be power of 2
|
||||
windowSize = (ringBlocks - 1) * blockBits
|
||||
blockMask = ringBlocks - 1
|
||||
bitMask = blockBits - 1
|
||||
)
|
||||
|
||||
const (
|
||||
CounterRedundantBitsLog = _WordLogSize + 3
|
||||
CounterRedundantBits = _WordSize * 8
|
||||
CounterBitsTotal = 2048
|
||||
CounterWindowSize = uint64(CounterBitsTotal - CounterRedundantBits)
|
||||
)
|
||||
|
||||
const (
|
||||
BacktrackWords = CounterBitsTotal / _WordSize
|
||||
)
|
||||
|
||||
func minUint64(a uint64, b uint64) uint64 {
|
||||
if a > b {
|
||||
return b
|
||||
}
|
||||
return a
|
||||
// A Filter rejects replayed messages by checking if message counter value is
|
||||
// within a sliding window of previously received messages.
|
||||
// The zero value for Filter is an empty filter ready to use.
|
||||
// Filters are unsafe for concurrent use.
|
||||
type Filter struct {
|
||||
last uint64
|
||||
ring [ringBlocks]block
|
||||
}
|
||||
|
||||
type ReplayFilter struct {
|
||||
counter uint64
|
||||
backtrack [BacktrackWords]uintptr
|
||||
// Reset resets the filter to empty state.
|
||||
func (f *Filter) Reset() {
|
||||
f.last = 0
|
||||
f.ring[0] = 0
|
||||
}
|
||||
|
||||
func (filter *ReplayFilter) Init() {
|
||||
filter.counter = 0
|
||||
filter.backtrack[0] = 0
|
||||
}
|
||||
|
||||
func (filter *ReplayFilter) ValidateCounter(counter uint64, limit uint64) bool {
|
||||
// ValidateCounter checks if the counter should be accepted.
|
||||
// Overlimit counters (>= limit) are always rejected.
|
||||
func (f *Filter) ValidateCounter(counter uint64, limit uint64) bool {
|
||||
if counter >= limit {
|
||||
return false
|
||||
}
|
||||
|
||||
indexWord := counter >> CounterRedundantBitsLog
|
||||
|
||||
if counter > filter.counter {
|
||||
|
||||
// move window forward
|
||||
|
||||
current := filter.counter >> CounterRedundantBitsLog
|
||||
diff := minUint64(indexWord-current, BacktrackWords)
|
||||
for i := uint64(1); i <= diff; i++ {
|
||||
filter.backtrack[(current+i)%BacktrackWords] = 0
|
||||
indexBlock := counter >> blockBitLog
|
||||
if counter > f.last { // move window forward
|
||||
current := f.last >> blockBitLog
|
||||
diff := indexBlock - current
|
||||
if diff > ringBlocks {
|
||||
diff = ringBlocks // cap diff to clear the whole ring
|
||||
}
|
||||
filter.counter = counter
|
||||
|
||||
} else if filter.counter-counter > CounterWindowSize {
|
||||
|
||||
// behind current window
|
||||
|
||||
for i := current + 1; i <= current+diff; i++ {
|
||||
f.ring[i&blockMask] = 0
|
||||
}
|
||||
f.last = counter
|
||||
} else if f.last-counter > windowSize { // behind current window
|
||||
return false
|
||||
}
|
||||
|
||||
indexWord %= BacktrackWords
|
||||
indexBit := counter & uint64(CounterRedundantBits-1)
|
||||
|
||||
// check and set bit
|
||||
|
||||
oldValue := filter.backtrack[indexWord]
|
||||
newValue := oldValue | (1 << indexBit)
|
||||
filter.backtrack[indexWord] = newValue
|
||||
return oldValue != newValue
|
||||
indexBlock &= blockMask
|
||||
indexBit := counter & bitMask
|
||||
old := f.ring[indexBlock]
|
||||
new := old | 1<<indexBit
|
||||
f.ring[indexBlock] = new
|
||||
return old != new
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package replay
|
||||
@@ -14,22 +14,22 @@ import (
|
||||
*
|
||||
*/
|
||||
|
||||
const RejectAfterMessages = (1 << 64) - (1 << 4) - 1
|
||||
const RejectAfterMessages = 1<<64 - 1<<13 - 1
|
||||
|
||||
func TestReplay(t *testing.T) {
|
||||
var filter ReplayFilter
|
||||
var filter Filter
|
||||
|
||||
T_LIM := CounterWindowSize + 1
|
||||
const T_LIM = windowSize + 1
|
||||
|
||||
testNumber := 0
|
||||
T := func(n uint64, v bool) {
|
||||
T := func(n uint64, expected bool) {
|
||||
testNumber++
|
||||
if filter.ValidateCounter(n, RejectAfterMessages) != v {
|
||||
t.Fatal("Test", testNumber, "failed", n, v)
|
||||
if filter.ValidateCounter(n, RejectAfterMessages) != expected {
|
||||
t.Fatal("Test", testNumber, "failed", n, expected)
|
||||
}
|
||||
}
|
||||
|
||||
filter.Init()
|
||||
filter.Reset()
|
||||
|
||||
T(0, true) /* 1 */
|
||||
T(1, true) /* 2 */
|
||||
@@ -67,53 +67,53 @@ func TestReplay(t *testing.T) {
|
||||
T(0, false) /* 34 */
|
||||
|
||||
t.Log("Bulk test 1")
|
||||
filter.Init()
|
||||
filter.Reset()
|
||||
testNumber = 0
|
||||
for i := uint64(1); i <= CounterWindowSize; i++ {
|
||||
for i := uint64(1); i <= windowSize; i++ {
|
||||
T(i, true)
|
||||
}
|
||||
T(0, true)
|
||||
T(0, false)
|
||||
|
||||
t.Log("Bulk test 2")
|
||||
filter.Init()
|
||||
filter.Reset()
|
||||
testNumber = 0
|
||||
for i := uint64(2); i <= CounterWindowSize+1; i++ {
|
||||
for i := uint64(2); i <= windowSize+1; i++ {
|
||||
T(i, true)
|
||||
}
|
||||
T(1, true)
|
||||
T(0, false)
|
||||
|
||||
t.Log("Bulk test 3")
|
||||
filter.Init()
|
||||
filter.Reset()
|
||||
testNumber = 0
|
||||
for i := CounterWindowSize + 1; i > 0; i-- {
|
||||
for i := uint64(windowSize + 1); i > 0; i-- {
|
||||
T(i, true)
|
||||
}
|
||||
|
||||
t.Log("Bulk test 4")
|
||||
filter.Init()
|
||||
filter.Reset()
|
||||
testNumber = 0
|
||||
for i := CounterWindowSize + 2; i > 1; i-- {
|
||||
for i := uint64(windowSize + 2); i > 1; i-- {
|
||||
T(i, true)
|
||||
}
|
||||
T(0, false)
|
||||
|
||||
t.Log("Bulk test 5")
|
||||
filter.Init()
|
||||
filter.Reset()
|
||||
testNumber = 0
|
||||
for i := CounterWindowSize; i > 0; i-- {
|
||||
for i := uint64(windowSize); i > 0; i-- {
|
||||
T(i, true)
|
||||
}
|
||||
T(CounterWindowSize+1, true)
|
||||
T(windowSize+1, true)
|
||||
T(0, false)
|
||||
|
||||
t.Log("Bulk test 6")
|
||||
filter.Init()
|
||||
filter.Reset()
|
||||
testNumber = 0
|
||||
for i := CounterWindowSize; i > 0; i-- {
|
||||
for i := uint64(windowSize); i > 0; i-- {
|
||||
T(i, true)
|
||||
}
|
||||
T(0, true)
|
||||
T(CounterWindowSize+1, true)
|
||||
T(windowSize+1, true)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
// +build !windows
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package rwcancel
|
||||
|
||||
@@ -1,8 +1,12 @@
|
||||
// +build !windows
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
// Package rwcancel implements cancelable read/write operations on
|
||||
// a file descriptor.
|
||||
package rwcancel
|
||||
|
||||
import (
|
||||
|
||||
8
rwcancel/rwcancel_windows.go
Normal file
8
rwcancel/rwcancel_windows.go
Normal file
@@ -0,0 +1,8 @@
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
package rwcancel
|
||||
|
||||
type RWCancel struct {
|
||||
}
|
||||
|
||||
func (*RWCancel) Cancel() {}
|
||||
@@ -1,8 +1,8 @@
|
||||
// +build !linux
|
||||
// +build !linux,!windows
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package rwcancel
|
||||
@@ -10,5 +10,6 @@ package rwcancel
|
||||
import "golang.org/x/sys/unix"
|
||||
|
||||
func unixSelect(nfd int, r *unix.FdSet, w *unix.FdSet, e *unix.FdSet, timeout *unix.Timeval) error {
|
||||
return unix.Select(nfd, r, w, e, timeout)
|
||||
_, err := unix.Select(nfd, r, w, e, timeout)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package rwcancel
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tai64n
|
||||
@@ -17,16 +17,19 @@ const whitenerMask = uint32(0x1000000 - 1)
|
||||
|
||||
type Timestamp [TimestampSize]byte
|
||||
|
||||
func Now() Timestamp {
|
||||
func stamp(t time.Time) Timestamp {
|
||||
var tai64n Timestamp
|
||||
now := time.Now()
|
||||
secs := base + uint64(now.Unix())
|
||||
nano := uint32(now.Nanosecond()) &^ whitenerMask
|
||||
secs := base + uint64(t.Unix())
|
||||
nano := uint32(t.Nanosecond()) &^ whitenerMask
|
||||
binary.BigEndian.PutUint64(tai64n[:], secs)
|
||||
binary.BigEndian.PutUint32(tai64n[8:], nano)
|
||||
return tai64n
|
||||
}
|
||||
|
||||
func Now() Timestamp {
|
||||
return stamp(time.Now())
|
||||
}
|
||||
|
||||
func (t1 Timestamp) After(t2 Timestamp) bool {
|
||||
return bytes.Compare(t1[:], t2[:]) > 0
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tai64n
|
||||
@@ -10,21 +10,31 @@ import (
|
||||
"time"
|
||||
)
|
||||
|
||||
/* Testing the essential property of the timestamp
|
||||
* as used by WireGuard.
|
||||
*/
|
||||
// Test that timestamps are monotonic as required by Wireguard and that
|
||||
// nanosecond-level information is whitened to prevent side channel attacks.
|
||||
func TestMonotonic(t *testing.T) {
|
||||
old := Now()
|
||||
for i := 0; i < 50; i++ {
|
||||
next := Now()
|
||||
if next.After(old) {
|
||||
t.Error("Whitening insufficient")
|
||||
startTime := time.Unix(0, 123456789) // a nontrivial bit pattern
|
||||
// Whitening should reduce timestamp granularity
|
||||
// to more than 10 but fewer than 20 milliseconds.
|
||||
tests := []struct {
|
||||
name string
|
||||
t1, t2 time.Time
|
||||
wantAfter bool
|
||||
}{
|
||||
{"after_10_ns", startTime, startTime.Add(10 * time.Nanosecond), false},
|
||||
{"after_10_us", startTime, startTime.Add(10 * time.Microsecond), false},
|
||||
{"after_1_ms", startTime, startTime.Add(time.Millisecond), false},
|
||||
{"after_10_ms", startTime, startTime.Add(10 * time.Millisecond), false},
|
||||
{"after_20_ms", startTime, startTime.Add(20 * time.Millisecond), true},
|
||||
}
|
||||
time.Sleep(time.Duration(whitenerMask)/time.Nanosecond + 1)
|
||||
next = Now()
|
||||
if !next.After(old) {
|
||||
t.Error("Not monotonically increasing on whitened nano-second scale")
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ts1, ts2 := stamp(tt.t1), stamp(tt.t2)
|
||||
got := ts2.After(ts1)
|
||||
if got != tt.wantAfter {
|
||||
t.Errorf("after = %v; want %v", got, tt.wantAfter)
|
||||
}
|
||||
old = next
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
@@ -20,19 +20,6 @@ import (
|
||||
|
||||
const utunControlName = "com.apple.net.utun_control"
|
||||
|
||||
// _CTLIOCGINFO value derived from /usr/include/sys/{kern_control,ioccom}.h
|
||||
const _CTLIOCGINFO = (0x40000000 | 0x80000000) | ((100 & 0x1fff) << 16) | uint32(byte('N'))<<8 | 3
|
||||
|
||||
// sockaddr_ctl specifeid in /usr/include/sys/kern_control.h
|
||||
type sockaddrCtl struct {
|
||||
scLen uint8
|
||||
scFamily uint8
|
||||
ssSysaddr uint16
|
||||
scID uint32
|
||||
scUnit uint32
|
||||
scReserved [5]uint32
|
||||
}
|
||||
|
||||
type NativeTun struct {
|
||||
name string
|
||||
tunFile *os.File
|
||||
@@ -41,8 +28,6 @@ type NativeTun struct {
|
||||
routeSocket int
|
||||
}
|
||||
|
||||
var sockaddrCtlSize uintptr = 32
|
||||
|
||||
func retryInterfaceByIndex(index int) (iface *net.Interface, err error) {
|
||||
for i := 0; i < 20; i++ {
|
||||
iface, err = net.InterfaceByIndex(index)
|
||||
@@ -130,43 +115,21 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ctlInfo = &struct {
|
||||
ctlID uint32
|
||||
ctlName [96]byte
|
||||
}{}
|
||||
|
||||
copy(ctlInfo.ctlName[:], []byte(utunControlName))
|
||||
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
uintptr(_CTLIOCGINFO),
|
||||
uintptr(unsafe.Pointer(ctlInfo)),
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
return nil, fmt.Errorf("_CTLIOCGINFO: %v", errno)
|
||||
ctlInfo := &unix.CtlInfo{}
|
||||
copy(ctlInfo.Name[:], []byte(utunControlName))
|
||||
err = unix.IoctlCtlInfo(fd, ctlInfo)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("IoctlGetCtlInfo: %w", err)
|
||||
}
|
||||
|
||||
sc := sockaddrCtl{
|
||||
scLen: uint8(sockaddrCtlSize),
|
||||
scFamily: unix.AF_SYSTEM,
|
||||
ssSysaddr: 2,
|
||||
scID: ctlInfo.ctlID,
|
||||
scUnit: uint32(ifIndex) + 1,
|
||||
sc := &unix.SockaddrCtl{
|
||||
ID: ctlInfo.Id,
|
||||
Unit: uint32(ifIndex) + 1,
|
||||
}
|
||||
|
||||
scPointer := unsafe.Pointer(&sc)
|
||||
|
||||
_, _, errno = unix.RawSyscall(
|
||||
unix.SYS_CONNECT,
|
||||
uintptr(fd),
|
||||
uintptr(scPointer),
|
||||
uintptr(sockaddrCtlSize),
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
return nil, fmt.Errorf("SYS_CONNECT: %v", errno)
|
||||
err = unix.Connect(fd, sc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = syscall.SetNonblock(fd, true)
|
||||
@@ -230,27 +193,19 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Name() (string, error) {
|
||||
var ifName struct {
|
||||
name [16]byte
|
||||
}
|
||||
ifNameSize := uintptr(16)
|
||||
|
||||
var errno syscall.Errno
|
||||
var err error
|
||||
tun.operateOnFd(func(fd uintptr) {
|
||||
_, _, errno = unix.Syscall6(
|
||||
unix.SYS_GETSOCKOPT,
|
||||
fd,
|
||||
tun.name, err = unix.GetsockoptString(
|
||||
int(fd),
|
||||
2, /* #define SYSPROTO_CONTROL 2 */
|
||||
2, /* #define UTUN_OPT_IFNAME 2 */
|
||||
uintptr(unsafe.Pointer(&ifName)),
|
||||
uintptr(unsafe.Pointer(&ifNameSize)), 0)
|
||||
)
|
||||
})
|
||||
|
||||
if errno != 0 {
|
||||
return "", fmt.Errorf("SYS_GETSOCKOPT: %v", errno)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("GetSockoptString: %w", err)
|
||||
}
|
||||
|
||||
tun.name = string(ifName.name[:ifNameSize-1])
|
||||
return tun.name, nil
|
||||
}
|
||||
|
||||
@@ -320,11 +275,6 @@ func (tun *NativeTun) Close() error {
|
||||
}
|
||||
|
||||
func (tun *NativeTun) setMTU(n int) error {
|
||||
|
||||
// open datagram socket
|
||||
|
||||
var fd int
|
||||
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM,
|
||||
@@ -337,29 +287,18 @@ func (tun *NativeTun) setMTU(n int) error {
|
||||
|
||||
defer unix.Close(fd)
|
||||
|
||||
// do ioctl call
|
||||
|
||||
var ifr [32]byte
|
||||
copy(ifr[:], tun.name)
|
||||
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
uintptr(unix.SIOCSIFMTU),
|
||||
uintptr(unsafe.Pointer(&ifr[0])),
|
||||
)
|
||||
|
||||
if errno != 0 {
|
||||
return fmt.Errorf("failed to set MTU on %s", tun.name)
|
||||
var ifr unix.IfreqMTU
|
||||
copy(ifr.Name[:], tun.name)
|
||||
ifr.MTU = int32(n)
|
||||
err = unix.IoctlSetIfreqMTU(fd, &ifr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to set MTU on %s: %w", tun.name, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) MTU() (int, error) {
|
||||
|
||||
// open datagram socket
|
||||
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
unix.SOCK_DGRAM,
|
||||
@@ -372,19 +311,10 @@ func (tun *NativeTun) MTU() (int, error) {
|
||||
|
||||
defer unix.Close(fd)
|
||||
|
||||
// do ioctl call
|
||||
|
||||
var ifr [64]byte
|
||||
copy(ifr[:], tun.name)
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
uintptr(unix.SIOCGIFMTU),
|
||||
uintptr(unsafe.Pointer(&ifr[0])),
|
||||
)
|
||||
if errno != 0 {
|
||||
return 0, fmt.Errorf("failed to get MTU on %s", tun.name)
|
||||
ifr, err := unix.IoctlGetIfreqMTU(fd, tun.name)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, err)
|
||||
}
|
||||
|
||||
return int(*(*int32)(unsafe.Pointer(&ifr[16]))), nil
|
||||
return int(ifr.MTU), nil
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
@@ -287,7 +287,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
||||
if errno != 0 {
|
||||
tunFile.Close()
|
||||
tunDestroy(assignedName)
|
||||
return nil, fmt.Errorf("Unable to put into IFHEAD mode: %v", errno)
|
||||
return nil, fmt.Errorf("Unable to put into IFHEAD mode: %w", errno)
|
||||
}
|
||||
|
||||
// Open control sockets
|
||||
@@ -328,7 +328,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
||||
if errno != 0 {
|
||||
tunFile.Close()
|
||||
tunDestroy(assignedName)
|
||||
return nil, fmt.Errorf("Unable to get nd6 flags for %s: %v", assignedName, errno)
|
||||
return nil, fmt.Errorf("Unable to get nd6 flags for %s: %w", assignedName, errno)
|
||||
}
|
||||
ndireq.Flags = ndireq.Flags &^ ND6_IFF_AUTO_LINKLOCAL
|
||||
ndireq.Flags = ndireq.Flags | ND6_IFF_NO_DAD
|
||||
@@ -341,7 +341,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
||||
if errno != 0 {
|
||||
tunFile.Close()
|
||||
tunDestroy(assignedName)
|
||||
return nil, fmt.Errorf("Unable to set nd6 flags for %s: %v", assignedName, errno)
|
||||
return nil, fmt.Errorf("Unable to set nd6 flags for %s: %w", assignedName, errno)
|
||||
}
|
||||
|
||||
// Rename the interface
|
||||
@@ -359,7 +359,7 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
||||
if errno != 0 {
|
||||
tunFile.Close()
|
||||
tunDestroy(assignedName)
|
||||
return nil, fmt.Errorf("Failed to rename %s to %s: %v", assignedName, name, errno)
|
||||
return nil, fmt.Errorf("Failed to rename %s to %s: %w", assignedName, name, errno)
|
||||
}
|
||||
|
||||
return CreateTUNFromFile(tunFile, mtu)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"syscall"
|
||||
@@ -32,7 +31,6 @@ const (
|
||||
type NativeTun struct {
|
||||
tunFile *os.File
|
||||
index int32 // if index
|
||||
name string // name of interface
|
||||
errors chan error // async error handling
|
||||
events chan Event // device related events
|
||||
nopi bool // the device was passed IFF_NO_PI
|
||||
@@ -40,6 +38,10 @@ type NativeTun struct {
|
||||
netlinkCancel *rwcancel.RWCancel
|
||||
hackListenerClosed sync.Mutex
|
||||
statusListenersShutdown chan struct{}
|
||||
|
||||
nameOnce sync.Once // guards calling initNameCache, which sets following fields
|
||||
nameCache string // name of interface
|
||||
nameErr error
|
||||
}
|
||||
|
||||
func (tun *NativeTun) File() *os.File {
|
||||
@@ -64,14 +66,19 @@ func (tun *NativeTun) routineHackListener() {
|
||||
}
|
||||
switch err {
|
||||
case unix.EINVAL:
|
||||
// If the tunnel is up, it reports that write() is
|
||||
// allowed but we provided invalid data.
|
||||
tun.events <- EventUp
|
||||
case unix.EIO:
|
||||
// If the tunnel is down, it reports that no I/O
|
||||
// is possible, without checking our provided data.
|
||||
tun.events <- EventDown
|
||||
default:
|
||||
return
|
||||
}
|
||||
select {
|
||||
case <-time.After(time.Second):
|
||||
// nothing
|
||||
case <-tun.statusListenersShutdown:
|
||||
return
|
||||
}
|
||||
@@ -85,7 +92,7 @@ func createNetlinkSocket() (int, error) {
|
||||
}
|
||||
saddr := &unix.SockaddrNetlink{
|
||||
Family: unix.AF_NETLINK,
|
||||
Groups: uint32((1 << (unix.RTNLGRP_LINK - 1)) | (1 << (unix.RTNLGRP_IPV4_IFADDR - 1)) | (1 << (unix.RTNLGRP_IPV6_IFADDR - 1))),
|
||||
Groups: unix.RTMGRP_LINK | unix.RTMGRP_IPV4_IFADDR | unix.RTMGRP_IPV6_IFADDR,
|
||||
}
|
||||
err = unix.Bind(sock, saddr)
|
||||
if err != nil {
|
||||
@@ -126,6 +133,7 @@ func (tun *NativeTun) routineNetlinkListener() {
|
||||
default:
|
||||
}
|
||||
|
||||
wasEverUp := false
|
||||
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
|
||||
|
||||
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
|
||||
@@ -149,11 +157,17 @@ func (tun *NativeTun) routineNetlinkListener() {
|
||||
|
||||
if info.Flags&unix.IFF_RUNNING != 0 {
|
||||
tun.events <- EventUp
|
||||
wasEverUp = true
|
||||
}
|
||||
|
||||
if info.Flags&unix.IFF_RUNNING == 0 {
|
||||
// Don't emit EventDown before we've ever emitted EventUp.
|
||||
// This avoids a startup race with HackListener, which
|
||||
// might detect Up before we have finished reporting Down.
|
||||
if wasEverUp {
|
||||
tun.events <- EventDown
|
||||
}
|
||||
}
|
||||
|
||||
tun.events <- EventMTUUpdate
|
||||
|
||||
@@ -164,11 +178,6 @@ func (tun *NativeTun) routineNetlinkListener() {
|
||||
}
|
||||
}
|
||||
|
||||
func (tun *NativeTun) isUp() (bool, error) {
|
||||
inter, err := net.InterfaceByName(tun.name)
|
||||
return inter.Flags&net.FlagUp != 0, err
|
||||
}
|
||||
|
||||
func getIFIndex(name string) (int32, error) {
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
@@ -198,6 +207,11 @@ func getIFIndex(name string) (int32, error) {
|
||||
}
|
||||
|
||||
func (tun *NativeTun) setMTU(n int) error {
|
||||
name, err := tun.Name()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// open datagram socket
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
@@ -212,9 +226,8 @@ func (tun *NativeTun) setMTU(n int) error {
|
||||
defer unix.Close(fd)
|
||||
|
||||
// do ioctl call
|
||||
|
||||
var ifr [ifReqSize]byte
|
||||
copy(ifr[:], tun.name)
|
||||
copy(ifr[:], name)
|
||||
*(*uint32)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = uint32(n)
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
@@ -231,6 +244,11 @@ func (tun *NativeTun) setMTU(n int) error {
|
||||
}
|
||||
|
||||
func (tun *NativeTun) MTU() (int, error) {
|
||||
name, err := tun.Name()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
// open datagram socket
|
||||
fd, err := unix.Socket(
|
||||
unix.AF_INET,
|
||||
@@ -247,7 +265,7 @@ func (tun *NativeTun) MTU() (int, error) {
|
||||
// do ioctl call
|
||||
|
||||
var ifr [ifReqSize]byte
|
||||
copy(ifr[:], tun.name)
|
||||
copy(ifr[:], name)
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
uintptr(fd),
|
||||
@@ -262,6 +280,15 @@ func (tun *NativeTun) MTU() (int, error) {
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Name() (string, error) {
|
||||
tun.nameOnce.Do(tun.initNameCache)
|
||||
return tun.nameCache, tun.nameErr
|
||||
}
|
||||
|
||||
func (tun *NativeTun) initNameCache() {
|
||||
tun.nameCache, tun.nameErr = tun.nameSlow()
|
||||
}
|
||||
|
||||
func (tun *NativeTun) nameSlow() (string, error) {
|
||||
sysconn, err := tun.tunFile.SyscallConn()
|
||||
if err != nil {
|
||||
return "", err
|
||||
@@ -282,13 +309,11 @@ func (tun *NativeTun) Name() (string, error) {
|
||||
if errno != 0 {
|
||||
return "", errors.New("failed to get name of TUN device: " + errno.Error())
|
||||
}
|
||||
nullStr := ifr[:]
|
||||
i := bytes.IndexByte(nullStr, 0)
|
||||
if i != -1 {
|
||||
nullStr = nullStr[:i]
|
||||
name := ifr[:]
|
||||
if i := bytes.IndexByte(name, 0); i != -1 {
|
||||
name = name[:i]
|
||||
}
|
||||
tun.name = string(nullStr)
|
||||
return tun.name, nil
|
||||
return string(name), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||
@@ -367,6 +392,9 @@ func (tun *NativeTun) Close() error {
|
||||
func CreateTUN(name string, mtu int) (Device, error) {
|
||||
nfd, err := unix.Open(cloneDevicePath, os.O_RDWR, 0)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil, fmt.Errorf("CreateTUN(%q) failed; %s does not exist", name, cloneDevicePath)
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -408,16 +436,15 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||
statusListenersShutdown: make(chan struct{}),
|
||||
nopi: false,
|
||||
}
|
||||
var err error
|
||||
|
||||
_, err = tun.Name()
|
||||
name, err := tun.Name()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// start event listener
|
||||
|
||||
tun.index, err = getIFIndex(tun.name)
|
||||
tun.index, err = getIFIndex(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2018-2019 WireGuard LLC. All Rights Reserved.
|
||||
* Copyright (C) 2018-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tun
|
||||
@@ -9,10 +9,9 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
_ "unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
@@ -33,18 +32,26 @@ type rateJuggler struct {
|
||||
}
|
||||
|
||||
type NativeTun struct {
|
||||
wt *wintun.Interface
|
||||
wt *wintun.Adapter
|
||||
handle windows.Handle
|
||||
close bool
|
||||
events chan Event
|
||||
errors chan error
|
||||
forcedMTU int
|
||||
rate rateJuggler
|
||||
rings *wintun.RingDescriptor
|
||||
writeLock sync.Mutex
|
||||
session wintun.Session
|
||||
readWait windows.Handle
|
||||
}
|
||||
|
||||
const WintunPool = wintun.Pool("WireGuard")
|
||||
var WintunPool *wintun.Pool
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
WintunPool, err = wintun.MakePool("WireGuard")
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("Failed to make pool: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
//go:linkname procyield runtime.procyield
|
||||
func procyield(cycles uint32)
|
||||
@@ -66,20 +73,20 @@ func CreateTUN(ifname string, mtu int) (Device, error) {
|
||||
//
|
||||
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
|
||||
var err error
|
||||
var wt *wintun.Interface
|
||||
var wt *wintun.Adapter
|
||||
|
||||
// Does an interface with this name already exist?
|
||||
wt, err = WintunPool.GetInterface(ifname)
|
||||
wt, err = WintunPool.OpenAdapter(ifname)
|
||||
if err == nil {
|
||||
// If so, we delete it, in case it has weird residual configuration.
|
||||
_, err = wt.DeleteInterface()
|
||||
_, err = wt.Delete(true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error deleting already existing interface: %v", err)
|
||||
return nil, fmt.Errorf("Error deleting already existing interface: %w", err)
|
||||
}
|
||||
}
|
||||
wt, _, err = WintunPool.CreateInterface(ifname, requestedGUID)
|
||||
wt, _, err = WintunPool.CreateAdapter(ifname, requestedGUID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Error creating interface: %v", err)
|
||||
return nil, fmt.Errorf("Error creating interface: %w", err)
|
||||
}
|
||||
|
||||
forcedMTU := 1420
|
||||
@@ -95,17 +102,13 @@ func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu
|
||||
forcedMTU: forcedMTU,
|
||||
}
|
||||
|
||||
tun.rings, err = wintun.NewRingDescriptor()
|
||||
tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
|
||||
if err != nil {
|
||||
tun.Close()
|
||||
return nil, fmt.Errorf("Error creating events: %v", err)
|
||||
}
|
||||
|
||||
tun.handle, err = tun.wt.Register(tun.rings)
|
||||
if err != nil {
|
||||
tun.Close()
|
||||
return nil, fmt.Errorf("Error registering rings: %v", err)
|
||||
_, err = tun.wt.Delete(false)
|
||||
close(tun.events)
|
||||
return nil, fmt.Errorf("Error starting session: %w", err)
|
||||
}
|
||||
tun.readWait = tun.session.ReadWaitEvent()
|
||||
return tun, nil
|
||||
}
|
||||
|
||||
@@ -123,16 +126,10 @@ func (tun *NativeTun) Events() chan Event {
|
||||
|
||||
func (tun *NativeTun) Close() error {
|
||||
tun.close = true
|
||||
if tun.rings.Send.TailMoved != 0 {
|
||||
windows.SetEvent(tun.rings.Send.TailMoved) // wake the reader if it's sleeping
|
||||
}
|
||||
if tun.handle != windows.InvalidHandle {
|
||||
windows.CloseHandle(tun.handle)
|
||||
}
|
||||
tun.rings.Close()
|
||||
tun.session.End()
|
||||
var err error
|
||||
if tun.wt != nil {
|
||||
_, err = tun.wt.DeleteInterface()
|
||||
_, err = tun.wt.Delete(false)
|
||||
}
|
||||
close(tun.events)
|
||||
return err
|
||||
@@ -156,56 +153,34 @@ retry:
|
||||
return 0, err
|
||||
default:
|
||||
}
|
||||
if tun.close {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
buffHead := atomic.LoadUint32(&tun.rings.Send.Ring.Head)
|
||||
if buffHead >= wintun.PacketCapacity {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
start := nanotime()
|
||||
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
|
||||
var buffTail uint32
|
||||
for {
|
||||
buffTail = atomic.LoadUint32(&tun.rings.Send.Ring.Tail)
|
||||
if buffHead != buffTail {
|
||||
break
|
||||
}
|
||||
if tun.close {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
packet, err := tun.session.ReceivePacket()
|
||||
switch err {
|
||||
case nil:
|
||||
packetSize := len(packet)
|
||||
copy(buff[offset:], packet)
|
||||
tun.session.ReleaseReceivePacket(packet)
|
||||
tun.rate.update(uint64(packetSize))
|
||||
return packetSize, nil
|
||||
case windows.ERROR_NO_MORE_ITEMS:
|
||||
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
|
||||
windows.WaitForSingleObject(tun.rings.Send.TailMoved, windows.INFINITE)
|
||||
windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
|
||||
goto retry
|
||||
}
|
||||
procyield(1)
|
||||
}
|
||||
if buffTail >= wintun.PacketCapacity {
|
||||
continue
|
||||
case windows.ERROR_HANDLE_EOF:
|
||||
return 0, os.ErrClosed
|
||||
case windows.ERROR_INVALID_DATA:
|
||||
return 0, errors.New("Send ring corrupt")
|
||||
}
|
||||
|
||||
buffContent := tun.rings.Send.Ring.Wrap(buffTail - buffHead)
|
||||
if buffContent < uint32(unsafe.Sizeof(wintun.PacketHeader{})) {
|
||||
return 0, errors.New("incomplete packet header in send ring")
|
||||
return 0, fmt.Errorf("Read failed: %w", err)
|
||||
}
|
||||
|
||||
packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Send.Ring.Data[buffHead]))
|
||||
if packet.Size > wintun.PacketSizeMax {
|
||||
return 0, errors.New("packet too big in send ring")
|
||||
}
|
||||
|
||||
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packet.Size)
|
||||
if alignedPacketSize > buffContent {
|
||||
return 0, errors.New("incomplete packet in send ring")
|
||||
}
|
||||
|
||||
copy(buff[offset:], packet.Data[:packet.Size])
|
||||
buffHead = tun.rings.Send.Ring.Wrap(buffHead + alignedPacketSize)
|
||||
atomic.StoreUint32(&tun.rings.Send.Ring.Head, buffHead)
|
||||
tun.rate.update(uint64(packet.Size))
|
||||
return int(packet.Size), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Flush() error {
|
||||
@@ -217,36 +192,22 @@ func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
packetSize := uint32(len(buff) - offset)
|
||||
packetSize := len(buff) - offset
|
||||
tun.rate.update(uint64(packetSize))
|
||||
alignedPacketSize := wintun.PacketAlign(uint32(unsafe.Sizeof(wintun.PacketHeader{})) + packetSize)
|
||||
|
||||
tun.writeLock.Lock()
|
||||
defer tun.writeLock.Unlock()
|
||||
|
||||
buffHead := atomic.LoadUint32(&tun.rings.Receive.Ring.Head)
|
||||
if buffHead >= wintun.PacketCapacity {
|
||||
return 0, os.ErrClosed
|
||||
packet, err := tun.session.AllocateSendPacket(packetSize)
|
||||
if err == nil {
|
||||
copy(packet, buff[offset:])
|
||||
tun.session.SendPacket(packet)
|
||||
return packetSize, nil
|
||||
}
|
||||
|
||||
buffTail := atomic.LoadUint32(&tun.rings.Receive.Ring.Tail)
|
||||
if buffTail >= wintun.PacketCapacity {
|
||||
switch err {
|
||||
case windows.ERROR_HANDLE_EOF:
|
||||
return 0, os.ErrClosed
|
||||
}
|
||||
|
||||
buffSpace := tun.rings.Receive.Ring.Wrap(buffHead - buffTail - wintun.PacketAlignment)
|
||||
if alignedPacketSize > buffSpace {
|
||||
case windows.ERROR_BUFFER_OVERFLOW:
|
||||
return 0, nil // Dropping when ring is full.
|
||||
}
|
||||
|
||||
packet := (*wintun.Packet)(unsafe.Pointer(&tun.rings.Receive.Ring.Data[buffTail]))
|
||||
packet.Size = packetSize
|
||||
copy(packet.Data[:packetSize], buff[offset:])
|
||||
atomic.StoreUint32(&tun.rings.Receive.Ring.Tail, tun.rings.Receive.Ring.Wrap(buffTail+alignedPacketSize))
|
||||
if atomic.LoadInt32(&tun.rings.Receive.Ring.Alertable) != 0 {
|
||||
windows.SetEvent(tun.rings.Receive.TailMoved)
|
||||
}
|
||||
return int(packetSize), nil
|
||||
return 0, fmt.Errorf("Write failed: %w", err)
|
||||
}
|
||||
|
||||
// LUID returns Windows interface instance ID.
|
||||
@@ -254,9 +215,9 @@ func (tun *NativeTun) LUID() uint64 {
|
||||
return tun.wt.LUID()
|
||||
}
|
||||
|
||||
// Version returns the version of the Wintun driver and NDIS system currently loaded.
|
||||
func (tun *NativeTun) Version() (driverVersion string, ndisVersion string, err error) {
|
||||
return tun.wt.Version()
|
||||
// RunningVersion returns the running version of the Wintun driver.
|
||||
func (tun *NativeTun) RunningVersion() (version uint32, err error) {
|
||||
return wintun.RunningVersion()
|
||||
}
|
||||
|
||||
func (rate *rateJuggler) update(packetLen uint64) {
|
||||
|
||||
150
tun/tuntest/tuntest.go
Normal file
150
tun/tuntest/tuntest.go
Normal file
@@ -0,0 +1,150 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package tuntest
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun"
|
||||
)
|
||||
|
||||
func Ping(dst, src net.IP) []byte {
|
||||
localPort := uint16(1337)
|
||||
seq := uint16(0)
|
||||
|
||||
payload := make([]byte, 4)
|
||||
binary.BigEndian.PutUint16(payload[0:], localPort)
|
||||
binary.BigEndian.PutUint16(payload[2:], seq)
|
||||
|
||||
return genICMPv4(payload, dst, src)
|
||||
}
|
||||
|
||||
// Checksum is the "internet checksum" from https://tools.ietf.org/html/rfc1071.
|
||||
func checksum(buf []byte, initial uint16) uint16 {
|
||||
v := uint32(initial)
|
||||
for i := 0; i < len(buf)-1; i += 2 {
|
||||
v += uint32(binary.BigEndian.Uint16(buf[i:]))
|
||||
}
|
||||
if len(buf)%2 == 1 {
|
||||
v += uint32(buf[len(buf)-1]) << 8
|
||||
}
|
||||
for v > 0xffff {
|
||||
v = (v >> 16) + (v & 0xffff)
|
||||
}
|
||||
return ^uint16(v)
|
||||
}
|
||||
|
||||
func genICMPv4(payload []byte, dst, src net.IP) []byte {
|
||||
const (
|
||||
icmpv4ProtocolNumber = 1
|
||||
icmpv4Echo = 8
|
||||
icmpv4ChecksumOffset = 2
|
||||
icmpv4Size = 8
|
||||
ipv4Size = 20
|
||||
ipv4TotalLenOffset = 2
|
||||
ipv4ChecksumOffset = 10
|
||||
ttl = 65
|
||||
)
|
||||
|
||||
hdr := make([]byte, ipv4Size+icmpv4Size)
|
||||
|
||||
ip := hdr[0:ipv4Size]
|
||||
icmpv4 := hdr[ipv4Size : ipv4Size+icmpv4Size]
|
||||
|
||||
// https://tools.ietf.org/html/rfc792
|
||||
icmpv4[0] = icmpv4Echo // type
|
||||
icmpv4[1] = 0 // code
|
||||
chksum := ^checksum(icmpv4, checksum(payload, 0))
|
||||
binary.BigEndian.PutUint16(icmpv4[icmpv4ChecksumOffset:], chksum)
|
||||
|
||||
// https://tools.ietf.org/html/rfc760 section 3.1
|
||||
length := uint16(len(hdr) + len(payload))
|
||||
ip[0] = (4 << 4) | (ipv4Size / 4)
|
||||
binary.BigEndian.PutUint16(ip[ipv4TotalLenOffset:], length)
|
||||
ip[8] = ttl
|
||||
ip[9] = icmpv4ProtocolNumber
|
||||
copy(ip[12:], src.To4())
|
||||
copy(ip[16:], dst.To4())
|
||||
chksum = ^checksum(ip[:], 0)
|
||||
binary.BigEndian.PutUint16(ip[ipv4ChecksumOffset:], chksum)
|
||||
|
||||
var v []byte
|
||||
v = append(v, hdr...)
|
||||
v = append(v, payload...)
|
||||
return []byte(v)
|
||||
}
|
||||
|
||||
// TODO(crawshaw): find a reusable home for this. package devicetest?
|
||||
type ChannelTUN struct {
|
||||
Inbound chan []byte // incoming packets, closed on TUN close
|
||||
Outbound chan []byte // outbound packets, blocks forever on TUN close
|
||||
|
||||
closed chan struct{}
|
||||
events chan tun.Event
|
||||
tun chTun
|
||||
}
|
||||
|
||||
func NewChannelTUN() *ChannelTUN {
|
||||
c := &ChannelTUN{
|
||||
Inbound: make(chan []byte),
|
||||
Outbound: make(chan []byte),
|
||||
closed: make(chan struct{}),
|
||||
events: make(chan tun.Event, 1),
|
||||
}
|
||||
c.tun.c = c
|
||||
c.events <- tun.EventUp
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *ChannelTUN) TUN() tun.Device {
|
||||
return &c.tun
|
||||
}
|
||||
|
||||
type chTun struct {
|
||||
c *ChannelTUN
|
||||
}
|
||||
|
||||
func (t *chTun) File() *os.File { return nil }
|
||||
|
||||
func (t *chTun) Read(data []byte, offset int) (int, error) {
|
||||
select {
|
||||
case <-t.c.closed:
|
||||
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
|
||||
case msg := <-t.c.Outbound:
|
||||
return copy(data[offset:], msg), nil
|
||||
}
|
||||
}
|
||||
|
||||
// Write is called by the wireguard device to deliver a packet for routing.
|
||||
func (t *chTun) Write(data []byte, offset int) (int, error) {
|
||||
if offset == -1 {
|
||||
close(t.c.closed)
|
||||
close(t.c.events)
|
||||
return 0, io.EOF
|
||||
}
|
||||
msg := make([]byte, len(data)-offset)
|
||||
copy(msg, data[offset:])
|
||||
select {
|
||||
case <-t.c.closed:
|
||||
return 0, io.EOF // TODO(crawshaw): what is the correct error value?
|
||||
case t.c.Inbound <- msg:
|
||||
return len(data) - offset, nil
|
||||
}
|
||||
}
|
||||
|
||||
const DefaultMTU = 1420
|
||||
|
||||
func (t *chTun) Flush() error { return nil }
|
||||
func (t *chTun) MTU() (int, error) { return DefaultMTU, nil }
|
||||
func (t *chTun) Name() (string, error) { return "loopbackTun1", nil }
|
||||
func (t *chTun) Events() chan tun.Event { return t.c.events }
|
||||
func (t *chTun) Close() error {
|
||||
t.Write(nil, -1)
|
||||
return nil
|
||||
}
|
||||
50
tun/wintun/dll_fromfile_windows.go
Normal file
50
tun/wintun/dll_fromfile_windows.go
Normal file
@@ -0,0 +1,50 @@
|
||||
// +build !load_wintun_from_rsrc
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package wintun
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
type lazyDLL struct {
|
||||
Name string
|
||||
mu sync.Mutex
|
||||
module windows.Handle
|
||||
}
|
||||
|
||||
func (d *lazyDLL) Load() error {
|
||||
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
|
||||
return nil
|
||||
}
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
if d.module != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
LOAD_LIBRARY_SEARCH_APPLICATION_DIR = 0x00000200
|
||||
LOAD_LIBRARY_SEARCH_SYSTEM32 = 0x00000800
|
||||
)
|
||||
module, err := windows.LoadLibraryEx(d.Name, 0, LOAD_LIBRARY_SEARCH_APPLICATION_DIR|LOAD_LIBRARY_SEARCH_SYSTEM32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to load library: %w", err)
|
||||
}
|
||||
|
||||
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *lazyProc) nameToAddr() (uintptr, error) {
|
||||
return windows.GetProcAddress(p.dll.module, p.Name)
|
||||
}
|
||||
58
tun/wintun/dll_fromrsrc_windows.go
Normal file
58
tun/wintun/dll_fromrsrc_windows.go
Normal file
@@ -0,0 +1,58 @@
|
||||
// +build load_wintun_from_rsrc
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package wintun
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/wintun/memmod"
|
||||
"golang.zx2c4.com/wireguard/tun/wintun/resource"
|
||||
)
|
||||
|
||||
type lazyDLL struct {
|
||||
Name string
|
||||
mu sync.Mutex
|
||||
module *memmod.Module
|
||||
}
|
||||
|
||||
func (d *lazyDLL) Load() error {
|
||||
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&d.module))) != nil {
|
||||
return nil
|
||||
}
|
||||
d.mu.Lock()
|
||||
defer d.mu.Unlock()
|
||||
if d.module != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
const ourModule windows.Handle = 0
|
||||
resInfo, err := resource.FindByName(ourModule, d.Name, resource.RT_RCDATA)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to find \"%v\" RCDATA resource: %w", d.Name, err)
|
||||
}
|
||||
data, err := resource.Load(ourModule, resInfo)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to load resource: %w", err)
|
||||
}
|
||||
module, err := memmod.LoadLibrary(data)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Unable to load library: %w", err)
|
||||
}
|
||||
|
||||
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&d.module)), unsafe.Pointer(module))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *lazyProc) nameToAddr() (uintptr, error) {
|
||||
return p.dll.module.ProcAddressByName(p.Name)
|
||||
}
|
||||
59
tun/wintun/dll_windows.go
Normal file
59
tun/wintun/dll_windows.go
Normal file
@@ -0,0 +1,59 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package wintun
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
func newLazyDLL(name string) *lazyDLL {
|
||||
return &lazyDLL{Name: name}
|
||||
}
|
||||
|
||||
func (d *lazyDLL) NewProc(name string) *lazyProc {
|
||||
return &lazyProc{dll: d, Name: name}
|
||||
}
|
||||
|
||||
type lazyProc struct {
|
||||
Name string
|
||||
mu sync.Mutex
|
||||
dll *lazyDLL
|
||||
addr uintptr
|
||||
}
|
||||
|
||||
func (p *lazyProc) Find() error {
|
||||
if atomic.LoadPointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr))) != nil {
|
||||
return nil
|
||||
}
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
if p.addr != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := p.dll.Load()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error loading %v DLL: %w", p.dll.Name, err)
|
||||
}
|
||||
addr, err := p.nameToAddr()
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error getting %v address: %w", p.Name, err)
|
||||
}
|
||||
|
||||
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&p.addr)), unsafe.Pointer(addr))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *lazyProc) Addr() uintptr {
|
||||
err := p.Find()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return p.addr
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package iphlpapi
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
//sys convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) = iphlpapi.ConvertInterfaceLuidToGuid
|
||||
//sys convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) = iphlpapi.ConvertInterfaceAliasToLuid
|
||||
|
||||
func InterfaceGUIDFromAlias(alias string) (*windows.GUID, error) {
|
||||
var luid uint64
|
||||
var guid windows.GUID
|
||||
err := convertInterfaceAliasToLUID(windows.StringToUTF16Ptr(alias), &luid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = convertInterfaceLUIDToGUID(&luid, &guid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &guid, nil
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package iphlpapi
|
||||
|
||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go conversion_windows.go
|
||||
@@ -1,60 +0,0 @@
|
||||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package iphlpapi
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return nil
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll")
|
||||
|
||||
procConvertInterfaceLuidToGuid = modiphlpapi.NewProc("ConvertInterfaceLuidToGuid")
|
||||
procConvertInterfaceAliasToLuid = modiphlpapi.NewProc("ConvertInterfaceAliasToLuid")
|
||||
)
|
||||
|
||||
func convertInterfaceLUIDToGUID(interfaceLUID *uint64, interfaceGUID *windows.GUID) (ret error) {
|
||||
r0, _, _ := syscall.Syscall(procConvertInterfaceLuidToGuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceLUID)), uintptr(unsafe.Pointer(interfaceGUID)), 0)
|
||||
if r0 != 0 {
|
||||
ret = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func convertInterfaceAliasToLUID(interfaceAlias *uint16, interfaceLUID *uint64) (ret error) {
|
||||
r0, _, _ := syscall.Syscall(procConvertInterfaceAliasToLuid.Addr(), 2, uintptr(unsafe.Pointer(interfaceAlias)), uintptr(unsafe.Pointer(interfaceLUID)), 0)
|
||||
if r0 != 0 {
|
||||
ret = syscall.Errno(r0)
|
||||
}
|
||||
return
|
||||
}
|
||||
620
tun/wintun/memmod/memmod_windows.go
Normal file
620
tun/wintun/memmod/memmod_windows.go
Normal file
@@ -0,0 +1,620 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package memmod
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
type addressList struct {
|
||||
next *addressList
|
||||
address uintptr
|
||||
}
|
||||
|
||||
func (head *addressList) free() {
|
||||
for node := head; node != nil; node = node.next {
|
||||
windows.VirtualFree(node.address, 0, windows.MEM_RELEASE)
|
||||
}
|
||||
}
|
||||
|
||||
type Module struct {
|
||||
headers *IMAGE_NT_HEADERS
|
||||
codeBase uintptr
|
||||
modules []windows.Handle
|
||||
initialized bool
|
||||
isDLL bool
|
||||
isRelocated bool
|
||||
nameExports map[string]uint16
|
||||
entry uintptr
|
||||
blockedMemory *addressList
|
||||
}
|
||||
|
||||
func (module *Module) headerDirectory(idx int) *IMAGE_DATA_DIRECTORY {
|
||||
return &module.headers.OptionalHeader.DataDirectory[idx]
|
||||
}
|
||||
|
||||
func (module *Module) copySections(address uintptr, size uintptr, old_headers *IMAGE_NT_HEADERS) error {
|
||||
sections := module.headers.Sections()
|
||||
for i := range sections {
|
||||
if sections[i].SizeOfRawData == 0 {
|
||||
// Section doesn't contain data in the dll itself, but may define uninitialized data.
|
||||
sectionSize := old_headers.OptionalHeader.SectionAlignment
|
||||
if sectionSize == 0 {
|
||||
continue
|
||||
}
|
||||
dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
|
||||
uintptr(sectionSize),
|
||||
windows.MEM_COMMIT,
|
||||
windows.PAGE_READWRITE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error allocating section: %w", err)
|
||||
}
|
||||
|
||||
// Always use position from file to support alignments smaller than page size (allocation above will align to page size).
|
||||
dest = module.codeBase + uintptr(sections[i].VirtualAddress)
|
||||
// NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
|
||||
sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
|
||||
var dst []byte
|
||||
unsafeSlice(unsafe.Pointer(&dst), a2p(dest), int(sectionSize))
|
||||
for j := range dst {
|
||||
dst[j] = 0
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if size < uintptr(sections[i].PointerToRawData+sections[i].SizeOfRawData) {
|
||||
return errors.New("Incomplete section")
|
||||
}
|
||||
|
||||
// Commit memory block and copy data from dll.
|
||||
dest, err := windows.VirtualAlloc(module.codeBase+uintptr(sections[i].VirtualAddress),
|
||||
uintptr(sections[i].SizeOfRawData),
|
||||
windows.MEM_COMMIT,
|
||||
windows.PAGE_READWRITE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error allocating memory block: %w", err)
|
||||
}
|
||||
|
||||
// Always use position from file to support alignments smaller than page size (allocation above will align to page size).
|
||||
memcpy(
|
||||
module.codeBase+uintptr(sections[i].VirtualAddress),
|
||||
address+uintptr(sections[i].PointerToRawData),
|
||||
uintptr(sections[i].SizeOfRawData))
|
||||
// NOTE: On 64bit systems we truncate to 32bit here but expand again later when "PhysicalAddress" is used.
|
||||
sections[i].SetPhysicalAddress((uint32)(dest & 0xffffffff))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (module *Module) realSectionSize(section *IMAGE_SECTION_HEADER) uintptr {
|
||||
size := section.SizeOfRawData
|
||||
if size != 0 {
|
||||
return uintptr(size)
|
||||
}
|
||||
if (section.Characteristics & IMAGE_SCN_CNT_INITIALIZED_DATA) != 0 {
|
||||
return uintptr(module.headers.OptionalHeader.SizeOfInitializedData)
|
||||
}
|
||||
if (section.Characteristics & IMAGE_SCN_CNT_UNINITIALIZED_DATA) != 0 {
|
||||
return uintptr(module.headers.OptionalHeader.SizeOfUninitializedData)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
type sectionFinalizeData struct {
|
||||
address uintptr
|
||||
alignedAddress uintptr
|
||||
size uintptr
|
||||
characteristics uint32
|
||||
last bool
|
||||
}
|
||||
|
||||
func (module *Module) finalizeSection(sectionData *sectionFinalizeData) error {
|
||||
if sectionData.size == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if (sectionData.characteristics & IMAGE_SCN_MEM_DISCARDABLE) != 0 {
|
||||
// Section is not needed any more and can safely be freed.
|
||||
if sectionData.address == sectionData.alignedAddress &&
|
||||
(sectionData.last ||
|
||||
(sectionData.size%uintptr(module.headers.OptionalHeader.SectionAlignment)) == 0) {
|
||||
// Only allowed to decommit whole pages.
|
||||
windows.VirtualFree(sectionData.address, sectionData.size, windows.MEM_DECOMMIT)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// determine protection flags based on characteristics
|
||||
var ProtectionFlags = [8]uint32{
|
||||
windows.PAGE_NOACCESS, // not writeable, not readable, not executable
|
||||
windows.PAGE_EXECUTE, // not writeable, not readable, executable
|
||||
windows.PAGE_READONLY, // not writeable, readable, not executable
|
||||
windows.PAGE_EXECUTE_READ, // not writeable, readable, executable
|
||||
windows.PAGE_WRITECOPY, // writeable, not readable, not executable
|
||||
windows.PAGE_EXECUTE_WRITECOPY, // writeable, not readable, executable
|
||||
windows.PAGE_READWRITE, // writeable, readable, not executable
|
||||
windows.PAGE_EXECUTE_READWRITE, // writeable, readable, executable
|
||||
}
|
||||
protect := ProtectionFlags[sectionData.characteristics>>29]
|
||||
if (sectionData.characteristics & IMAGE_SCN_MEM_NOT_CACHED) != 0 {
|
||||
protect |= windows.PAGE_NOCACHE
|
||||
}
|
||||
|
||||
// Change memory access flags.
|
||||
var oldProtect uint32
|
||||
err := windows.VirtualProtect(sectionData.address, sectionData.size, protect, &oldProtect)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error protecting memory page: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (module *Module) finalizeSections() error {
|
||||
sections := module.headers.Sections()
|
||||
imageOffset := module.headers.OptionalHeader.imageOffset()
|
||||
sectionData := sectionFinalizeData{}
|
||||
sectionData.address = uintptr(sections[0].PhysicalAddress()) | imageOffset
|
||||
sectionData.alignedAddress = alignDown(sectionData.address, uintptr(module.headers.OptionalHeader.SectionAlignment))
|
||||
sectionData.size = module.realSectionSize(§ions[0])
|
||||
sectionData.characteristics = sections[0].Characteristics
|
||||
|
||||
// Loop through all sections and change access flags.
|
||||
for i := uint16(1); i < module.headers.FileHeader.NumberOfSections; i++ {
|
||||
sectionAddress := uintptr(sections[i].PhysicalAddress()) | imageOffset
|
||||
alignedAddress := alignDown(sectionAddress, uintptr(module.headers.OptionalHeader.SectionAlignment))
|
||||
sectionSize := module.realSectionSize(§ions[i])
|
||||
// Combine access flags of all sections that share a page.
|
||||
// TODO: We currently share flags of a trailing large section with the page of a first small section. This should be optimized.
|
||||
if sectionData.alignedAddress == alignedAddress || sectionData.address+sectionData.size > alignedAddress {
|
||||
// Section shares page with previous.
|
||||
if (sections[i].Characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 || (sectionData.characteristics&IMAGE_SCN_MEM_DISCARDABLE) == 0 {
|
||||
sectionData.characteristics = (sectionData.characteristics | sections[i].Characteristics) &^ IMAGE_SCN_MEM_DISCARDABLE
|
||||
} else {
|
||||
sectionData.characteristics |= sections[i].Characteristics
|
||||
}
|
||||
sectionData.size = sectionAddress + sectionSize - sectionData.address
|
||||
continue
|
||||
}
|
||||
|
||||
err := module.finalizeSection(§ionData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error finalizing section: %w", err)
|
||||
}
|
||||
sectionData.address = sectionAddress
|
||||
sectionData.alignedAddress = alignedAddress
|
||||
sectionData.size = sectionSize
|
||||
sectionData.characteristics = sections[i].Characteristics
|
||||
}
|
||||
sectionData.last = true
|
||||
err := module.finalizeSection(§ionData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error finalizing section: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (module *Module) executeTLS() {
|
||||
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_TLS)
|
||||
if directory.VirtualAddress == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
tls := (*IMAGE_TLS_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
|
||||
callback := tls.AddressOfCallbacks
|
||||
if callback != 0 {
|
||||
for {
|
||||
f := *(*uintptr)(a2p(callback))
|
||||
if f == 0 {
|
||||
break
|
||||
}
|
||||
syscall.Syscall(f, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), uintptr(0))
|
||||
callback += unsafe.Sizeof(f)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (module *Module) performBaseRelocation(delta uintptr) (relocated bool, err error) {
|
||||
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_BASERELOC)
|
||||
if directory.Size == 0 {
|
||||
return delta == 0, nil
|
||||
}
|
||||
|
||||
relocationHdr := (*IMAGE_BASE_RELOCATION)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
|
||||
for relocationHdr.VirtualAddress > 0 {
|
||||
dest := module.codeBase + uintptr(relocationHdr.VirtualAddress)
|
||||
|
||||
var relInfos []uint16
|
||||
unsafeSlice(
|
||||
unsafe.Pointer(&relInfos),
|
||||
a2p(uintptr(unsafe.Pointer(relocationHdr))+unsafe.Sizeof(*relocationHdr)),
|
||||
int((uintptr(relocationHdr.SizeOfBlock)-unsafe.Sizeof(*relocationHdr))/unsafe.Sizeof(relInfos[0])))
|
||||
for _, relInfo := range relInfos {
|
||||
// The upper 4 bits define the type of relocation.
|
||||
relType := relInfo >> 12
|
||||
// The lower 12 bits define the offset.
|
||||
relOffset := uintptr(relInfo & 0xfff)
|
||||
|
||||
switch relType {
|
||||
case IMAGE_REL_BASED_ABSOLUTE:
|
||||
// Skip relocation.
|
||||
|
||||
case IMAGE_REL_BASED_LOW:
|
||||
*(*uint16)(a2p(dest + relOffset)) += uint16(delta & 0xffff)
|
||||
break
|
||||
|
||||
case IMAGE_REL_BASED_HIGH:
|
||||
*(*uint16)(a2p(dest + relOffset)) += uint16(uint32(delta) >> 16)
|
||||
break
|
||||
|
||||
case IMAGE_REL_BASED_HIGHLOW:
|
||||
*(*uint32)(a2p(dest + relOffset)) += uint32(delta)
|
||||
|
||||
case IMAGE_REL_BASED_DIR64:
|
||||
*(*uint64)(a2p(dest + relOffset)) += uint64(delta)
|
||||
|
||||
case IMAGE_REL_BASED_THUMB_MOV32:
|
||||
inst := *(*uint32)(a2p(dest + relOffset))
|
||||
imm16 := ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
|
||||
((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
|
||||
if (inst & 0x8000fbf0) != 0x0000f240 {
|
||||
return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVW", inst)
|
||||
}
|
||||
imm16 += uint32(delta) & 0xffff
|
||||
hiDelta := (uint32(delta&0xffff0000) >> 16) + ((imm16 & 0xffff0000) >> 16)
|
||||
*(*uint32)(a2p(dest + relOffset)) = (inst & 0x8f00fbf0) + ((imm16 >> 1) & 0x0400) +
|
||||
((imm16 >> 12) & 0x000f) +
|
||||
((imm16 << 20) & 0x70000000) +
|
||||
((imm16 << 16) & 0xff0000)
|
||||
if hiDelta != 0 {
|
||||
inst = *(*uint32)(a2p(dest + relOffset + 4))
|
||||
imm16 = ((inst << 1) & 0x0800) + ((inst << 12) & 0xf000) +
|
||||
((inst >> 20) & 0x0700) + ((inst >> 16) & 0x00ff)
|
||||
if (inst & 0x8000fbf0) != 0x0000f2c0 {
|
||||
return false, fmt.Errorf("Wrong Thumb2 instruction %08x, expected MOVT", inst)
|
||||
}
|
||||
imm16 += hiDelta
|
||||
if imm16 > 0xffff {
|
||||
return false, fmt.Errorf("Resulting immediate value won't fit: %08x", imm16)
|
||||
}
|
||||
*(*uint32)(a2p(dest + relOffset + 4)) = (inst & 0x8f00fbf0) +
|
||||
((imm16 >> 1) & 0x0400) +
|
||||
((imm16 >> 12) & 0x000f) +
|
||||
((imm16 << 20) & 0x70000000) +
|
||||
((imm16 << 16) & 0xff0000)
|
||||
}
|
||||
|
||||
default:
|
||||
return false, fmt.Errorf("Unsupported relocation: %w", relType)
|
||||
}
|
||||
}
|
||||
|
||||
// Advance to next relocation block.
|
||||
relocationHdr = (*IMAGE_BASE_RELOCATION)(a2p(uintptr(unsafe.Pointer(relocationHdr)) + uintptr(relocationHdr.SizeOfBlock)))
|
||||
}
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (module *Module) buildImportTable() error {
|
||||
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_IMPORT)
|
||||
if directory.Size == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
module.modules = make([]windows.Handle, 0, 16)
|
||||
importDesc := (*IMAGE_IMPORT_DESCRIPTOR)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
|
||||
for !isBadReadPtr(uintptr(unsafe.Pointer(importDesc)), unsafe.Sizeof(*importDesc)) && importDesc.Name != 0 {
|
||||
handle, err := loadLibraryA((*byte)(a2p(module.codeBase + uintptr(importDesc.Name))))
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error loading module: %w", err)
|
||||
}
|
||||
var thunkRef, funcRef *uintptr
|
||||
if importDesc.OriginalFirstThunk() != 0 {
|
||||
thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.OriginalFirstThunk())))
|
||||
funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
|
||||
} else {
|
||||
// No hint table.
|
||||
thunkRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
|
||||
funcRef = (*uintptr)(a2p(module.codeBase + uintptr(importDesc.FirstThunk)))
|
||||
}
|
||||
for *thunkRef != 0 {
|
||||
if IMAGE_SNAP_BY_ORDINAL(*thunkRef) {
|
||||
*funcRef, err = getProcAddress(handle, (*byte)(a2p(IMAGE_ORDINAL(*thunkRef))))
|
||||
} else {
|
||||
thunkData := (*IMAGE_IMPORT_BY_NAME)(a2p(module.codeBase + *thunkRef))
|
||||
*funcRef, err = getProcAddress(handle, &thunkData.Name[0])
|
||||
}
|
||||
if err != nil {
|
||||
windows.FreeLibrary(handle)
|
||||
return fmt.Errorf("Error getting function address: %w", err)
|
||||
}
|
||||
thunkRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(thunkRef)) + unsafe.Sizeof(*thunkRef)))
|
||||
funcRef = (*uintptr)(a2p(uintptr(unsafe.Pointer(funcRef)) + unsafe.Sizeof(*funcRef)))
|
||||
}
|
||||
module.modules = append(module.modules, handle)
|
||||
importDesc = (*IMAGE_IMPORT_DESCRIPTOR)(a2p(uintptr(unsafe.Pointer(importDesc)) + unsafe.Sizeof(*importDesc)))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (module *Module) buildNameExports() error {
|
||||
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
|
||||
if directory.Size == 0 {
|
||||
return errors.New("No export table found")
|
||||
}
|
||||
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
|
||||
if exports.NumberOfNames == 0 || exports.NumberOfFunctions == 0 {
|
||||
return errors.New("No functions exported")
|
||||
}
|
||||
if exports.NumberOfNames == 0 {
|
||||
return errors.New("No functions exported by name")
|
||||
}
|
||||
var nameRefs []uint32
|
||||
unsafeSlice(unsafe.Pointer(&nameRefs), a2p(module.codeBase+uintptr(exports.AddressOfNames)), int(exports.NumberOfNames))
|
||||
var ordinals []uint16
|
||||
unsafeSlice(unsafe.Pointer(&ordinals), a2p(module.codeBase+uintptr(exports.AddressOfNameOrdinals)), int(exports.NumberOfNames))
|
||||
module.nameExports = make(map[string]uint16)
|
||||
for i := range nameRefs {
|
||||
nameArray := windows.BytePtrToString((*byte)(a2p(module.codeBase + uintptr(nameRefs[i]))))
|
||||
module.nameExports[nameArray] = ordinals[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadLibrary loads module image to memory.
|
||||
func LoadLibrary(data []byte) (module *Module, err error) {
|
||||
addr := uintptr(unsafe.Pointer(&data[0]))
|
||||
size := uintptr(len(data))
|
||||
if size < unsafe.Sizeof(IMAGE_DOS_HEADER{}) {
|
||||
return nil, errors.New("Incomplete IMAGE_DOS_HEADER")
|
||||
}
|
||||
dosHeader := (*IMAGE_DOS_HEADER)(a2p(addr))
|
||||
if dosHeader.E_magic != IMAGE_DOS_SIGNATURE {
|
||||
return nil, fmt.Errorf("Not an MS-DOS binary (provided: %x, expected: %x)", dosHeader.E_magic, IMAGE_DOS_SIGNATURE)
|
||||
}
|
||||
if (size < uintptr(dosHeader.E_lfanew)+unsafe.Sizeof(IMAGE_NT_HEADERS{})) {
|
||||
return nil, errors.New("Incomplete IMAGE_NT_HEADERS")
|
||||
}
|
||||
oldHeader := (*IMAGE_NT_HEADERS)(a2p(addr + uintptr(dosHeader.E_lfanew)))
|
||||
if oldHeader.Signature != IMAGE_NT_SIGNATURE {
|
||||
return nil, fmt.Errorf("Not an NT binary (provided: %x, expected: %x)", oldHeader.Signature, IMAGE_NT_SIGNATURE)
|
||||
}
|
||||
if oldHeader.FileHeader.Machine != imageFileProcess {
|
||||
return nil, fmt.Errorf("Foreign platform (provided: %x, expected: %x)", oldHeader.FileHeader.Machine, imageFileProcess)
|
||||
}
|
||||
if (oldHeader.OptionalHeader.SectionAlignment & 1) != 0 {
|
||||
return nil, errors.New("Unaligned section")
|
||||
}
|
||||
lastSectionEnd := uintptr(0)
|
||||
sections := oldHeader.Sections()
|
||||
optionalSectionSize := oldHeader.OptionalHeader.SectionAlignment
|
||||
for i := range sections {
|
||||
var endOfSection uintptr
|
||||
if sections[i].SizeOfRawData == 0 {
|
||||
// Section without data in the DLL
|
||||
endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(optionalSectionSize)
|
||||
} else {
|
||||
endOfSection = uintptr(sections[i].VirtualAddress) + uintptr(sections[i].SizeOfRawData)
|
||||
}
|
||||
if endOfSection > lastSectionEnd {
|
||||
lastSectionEnd = endOfSection
|
||||
}
|
||||
}
|
||||
alignedImageSize := alignUp(uintptr(oldHeader.OptionalHeader.SizeOfImage), uintptr(oldHeader.OptionalHeader.SectionAlignment))
|
||||
if alignedImageSize != alignUp(lastSectionEnd, uintptr(oldHeader.OptionalHeader.SectionAlignment)) {
|
||||
return nil, errors.New("Section is not page-aligned")
|
||||
}
|
||||
|
||||
module = &Module{isDLL: (oldHeader.FileHeader.Characteristics & IMAGE_FILE_DLL) != 0}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
module.Free()
|
||||
module = nil
|
||||
}
|
||||
}()
|
||||
|
||||
// Reserve memory for image of library.
|
||||
// TODO: Is it correct to commit the complete memory region at once? Calling DllEntry raises an exception if we don't.
|
||||
module.codeBase, err = windows.VirtualAlloc(oldHeader.OptionalHeader.ImageBase,
|
||||
alignedImageSize,
|
||||
windows.MEM_RESERVE|windows.MEM_COMMIT,
|
||||
windows.PAGE_READWRITE)
|
||||
if err != nil {
|
||||
// Try to allocate memory at arbitrary position.
|
||||
module.codeBase, err = windows.VirtualAlloc(0,
|
||||
alignedImageSize,
|
||||
windows.MEM_RESERVE|windows.MEM_COMMIT,
|
||||
windows.PAGE_READWRITE)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Error allocating code: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
err = module.check4GBBoundaries(alignedImageSize)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Error reallocating code: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
if size < uintptr(oldHeader.OptionalHeader.SizeOfHeaders) {
|
||||
err = errors.New("Incomplete headers")
|
||||
return
|
||||
}
|
||||
// Commit memory for headers.
|
||||
headers, err := windows.VirtualAlloc(module.codeBase,
|
||||
uintptr(oldHeader.OptionalHeader.SizeOfHeaders),
|
||||
windows.MEM_COMMIT,
|
||||
windows.PAGE_READWRITE)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Error allocating headers: %w", err)
|
||||
return
|
||||
}
|
||||
// Copy PE header to code.
|
||||
memcpy(headers, addr, uintptr(oldHeader.OptionalHeader.SizeOfHeaders))
|
||||
module.headers = (*IMAGE_NT_HEADERS)(a2p(headers + uintptr(dosHeader.E_lfanew)))
|
||||
|
||||
// Update position.
|
||||
module.headers.OptionalHeader.ImageBase = module.codeBase
|
||||
|
||||
// Copy sections from DLL file block to new memory location.
|
||||
err = module.copySections(addr, size, oldHeader)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Error copying sections: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Adjust base address of imported data.
|
||||
locationDelta := module.headers.OptionalHeader.ImageBase - oldHeader.OptionalHeader.ImageBase
|
||||
if locationDelta != 0 {
|
||||
module.isRelocated, err = module.performBaseRelocation(locationDelta)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Error relocating module: %w", err)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
module.isRelocated = true
|
||||
}
|
||||
|
||||
// Load required dlls and adjust function table of imports.
|
||||
err = module.buildImportTable()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Error building import table: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Mark memory pages depending on section headers and release sections that are marked as "discardable".
|
||||
err = module.finalizeSections()
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Error finalizing sections: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
// TLS callbacks are executed BEFORE the main loading.
|
||||
module.executeTLS()
|
||||
|
||||
// Get entry point of loaded module.
|
||||
if module.headers.OptionalHeader.AddressOfEntryPoint != 0 {
|
||||
module.entry = module.codeBase + uintptr(module.headers.OptionalHeader.AddressOfEntryPoint)
|
||||
if module.isDLL {
|
||||
// Notify library about attaching to process.
|
||||
r0, _, _ := syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_ATTACH), 0)
|
||||
successful := r0 != 0
|
||||
if !successful {
|
||||
err = windows.ERROR_DLL_INIT_FAILED
|
||||
return
|
||||
}
|
||||
module.initialized = true
|
||||
}
|
||||
}
|
||||
|
||||
module.buildNameExports()
|
||||
return
|
||||
}
|
||||
|
||||
// Free releases module resources and unloads it.
|
||||
func (module *Module) Free() {
|
||||
if module.initialized {
|
||||
// Notify library about detaching from process.
|
||||
syscall.Syscall(module.entry, 3, module.codeBase, uintptr(DLL_PROCESS_DETACH), 0)
|
||||
module.initialized = false
|
||||
}
|
||||
if module.modules != nil {
|
||||
// Free previously opened libraries.
|
||||
for _, handle := range module.modules {
|
||||
windows.FreeLibrary(handle)
|
||||
}
|
||||
module.modules = nil
|
||||
}
|
||||
if module.codeBase != 0 {
|
||||
windows.VirtualFree(module.codeBase, 0, windows.MEM_RELEASE)
|
||||
module.codeBase = 0
|
||||
}
|
||||
if module.blockedMemory != nil {
|
||||
module.blockedMemory.free()
|
||||
module.blockedMemory = nil
|
||||
}
|
||||
}
|
||||
|
||||
// ProcAddressByName returns function address by exported name.
|
||||
func (module *Module) ProcAddressByName(name string) (uintptr, error) {
|
||||
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
|
||||
if directory.Size == 0 {
|
||||
return 0, errors.New("No export table found")
|
||||
}
|
||||
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
|
||||
if module.nameExports == nil {
|
||||
return 0, errors.New("No functions exported by name")
|
||||
}
|
||||
if idx, ok := module.nameExports[name]; ok {
|
||||
if uint32(idx) > exports.NumberOfFunctions {
|
||||
return 0, errors.New("Ordinal number too high")
|
||||
}
|
||||
// AddressOfFunctions contains the RVAs to the "real" functions.
|
||||
return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
|
||||
}
|
||||
return 0, errors.New("Function not found by name")
|
||||
}
|
||||
|
||||
// ProcAddressByOrdinal returns function address by exported ordinal.
|
||||
func (module *Module) ProcAddressByOrdinal(ordinal uint16) (uintptr, error) {
|
||||
directory := module.headerDirectory(IMAGE_DIRECTORY_ENTRY_EXPORT)
|
||||
if directory.Size == 0 {
|
||||
return 0, errors.New("No export table found")
|
||||
}
|
||||
exports := (*IMAGE_EXPORT_DIRECTORY)(a2p(module.codeBase + uintptr(directory.VirtualAddress)))
|
||||
if uint32(ordinal) < exports.Base {
|
||||
return 0, errors.New("Ordinal number too low")
|
||||
}
|
||||
idx := ordinal - uint16(exports.Base)
|
||||
if uint32(idx) > exports.NumberOfFunctions {
|
||||
return 0, errors.New("Ordinal number too high")
|
||||
}
|
||||
// AddressOfFunctions contains the RVAs to the "real" functions.
|
||||
return module.codeBase + uintptr(*(*uint32)(a2p(module.codeBase + uintptr(exports.AddressOfFunctions) + uintptr(idx)*4))), nil
|
||||
}
|
||||
|
||||
func alignDown(value, alignment uintptr) uintptr {
|
||||
return value & ^(alignment - 1)
|
||||
}
|
||||
|
||||
func alignUp(value, alignment uintptr) uintptr {
|
||||
return (value + alignment - 1) & ^(alignment - 1)
|
||||
}
|
||||
|
||||
func a2p(addr uintptr) unsafe.Pointer {
|
||||
return unsafe.Pointer(addr)
|
||||
}
|
||||
|
||||
func memcpy(dst, src, size uintptr) {
|
||||
var d, s []byte
|
||||
unsafeSlice(unsafe.Pointer(&d), a2p(dst), int(size))
|
||||
unsafeSlice(unsafe.Pointer(&s), a2p(src), int(size))
|
||||
copy(d, s)
|
||||
}
|
||||
|
||||
// unsafeSlice updates the slice slicePtr to be a slice
|
||||
// referencing the provided data with its length & capacity set to
|
||||
// lenCap.
|
||||
//
|
||||
// TODO: when Go 1.16 or Go 1.17 is the minimum supported version,
|
||||
// update callers to use unsafe.Slice instead of this.
|
||||
func unsafeSlice(slicePtr, data unsafe.Pointer, lenCap int) {
|
||||
type sliceHeader struct {
|
||||
Data unsafe.Pointer
|
||||
Len int
|
||||
Cap int
|
||||
}
|
||||
h := (*sliceHeader)(slicePtr)
|
||||
h.Data = data
|
||||
h.Len = lenCap
|
||||
h.Cap = lenCap
|
||||
}
|
||||
16
tun/wintun/memmod/memmod_windows_32.go
Normal file
16
tun/wintun/memmod/memmod_windows_32.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// +build 386 arm
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package memmod
|
||||
|
||||
func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
|
||||
return
|
||||
}
|
||||
8
tun/wintun/memmod/memmod_windows_386.go
Normal file
8
tun/wintun/memmod/memmod_windows_386.go
Normal file
@@ -0,0 +1,8 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package memmod
|
||||
|
||||
const imageFileProcess = IMAGE_FILE_MACHINE_I386
|
||||
36
tun/wintun/memmod/memmod_windows_64.go
Normal file
36
tun/wintun/memmod/memmod_windows_64.go
Normal file
@@ -0,0 +1,36 @@
|
||||
// +build amd64 arm64
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package memmod
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func (opthdr *IMAGE_OPTIONAL_HEADER) imageOffset() uintptr {
|
||||
return uintptr(opthdr.ImageBase & 0xffffffff00000000)
|
||||
}
|
||||
|
||||
func (module *Module) check4GBBoundaries(alignedImageSize uintptr) (err error) {
|
||||
for (module.codeBase >> 32) < ((module.codeBase + alignedImageSize) >> 32) {
|
||||
node := &addressList{
|
||||
next: module.blockedMemory,
|
||||
address: module.codeBase,
|
||||
}
|
||||
module.blockedMemory = node
|
||||
module.codeBase, err = windows.VirtualAlloc(0,
|
||||
alignedImageSize,
|
||||
windows.MEM_RESERVE|windows.MEM_COMMIT,
|
||||
windows.PAGE_READWRITE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("Error allocating memory block: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
8
tun/wintun/memmod/memmod_windows_amd64.go
Normal file
8
tun/wintun/memmod/memmod_windows_amd64.go
Normal file
@@ -0,0 +1,8 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package memmod
|
||||
|
||||
const imageFileProcess = IMAGE_FILE_MACHINE_AMD64
|
||||
8
tun/wintun/memmod/memmod_windows_arm.go
Normal file
8
tun/wintun/memmod/memmod_windows_arm.go
Normal file
@@ -0,0 +1,8 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package memmod
|
||||
|
||||
const imageFileProcess = IMAGE_FILE_MACHINE_ARMNT
|
||||
8
tun/wintun/memmod/memmod_windows_arm64.go
Normal file
8
tun/wintun/memmod/memmod_windows_arm64.go
Normal file
@@ -0,0 +1,8 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package memmod
|
||||
|
||||
const imageFileProcess = IMAGE_FILE_MACHINE_ARM64
|
||||
8
tun/wintun/memmod/mksyscall.go
Normal file
8
tun/wintun/memmod/mksyscall.go
Normal file
@@ -0,0 +1,8 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package memmod
|
||||
|
||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go syscall_windows.go
|
||||
343
tun/wintun/memmod/syscall_windows.go
Normal file
343
tun/wintun/memmod/syscall_windows.go
Normal file
@@ -0,0 +1,343 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package memmod
|
||||
|
||||
import "unsafe"
|
||||
|
||||
const (
|
||||
IMAGE_DOS_SIGNATURE = 0x5A4D // MZ
|
||||
IMAGE_OS2_SIGNATURE = 0x454E // NE
|
||||
IMAGE_OS2_SIGNATURE_LE = 0x454C // LE
|
||||
IMAGE_VXD_SIGNATURE = 0x454C // LE
|
||||
IMAGE_NT_SIGNATURE = 0x00004550 // PE00
|
||||
)
|
||||
|
||||
// DOS .EXE header
|
||||
type IMAGE_DOS_HEADER struct {
|
||||
E_magic uint16 // Magic number
|
||||
E_cblp uint16 // Bytes on last page of file
|
||||
E_cp uint16 // Pages in file
|
||||
E_crlc uint16 // Relocations
|
||||
E_cparhdr uint16 // Size of header in paragraphs
|
||||
E_minalloc uint16 // Minimum extra paragraphs needed
|
||||
E_maxalloc uint16 // Maximum extra paragraphs needed
|
||||
E_ss uint16 // Initial (relative) SS value
|
||||
E_sp uint16 // Initial SP value
|
||||
E_csum uint16 // Checksum
|
||||
E_ip uint16 // Initial IP value
|
||||
E_cs uint16 // Initial (relative) CS value
|
||||
E_lfarlc uint16 // File address of relocation table
|
||||
E_ovno uint16 // Overlay number
|
||||
E_res [4]uint16 // Reserved words
|
||||
E_oemid uint16 // OEM identifier (for e_oeminfo)
|
||||
E_oeminfo uint16 // OEM information; e_oemid specific
|
||||
E_res2 [10]uint16 // Reserved words
|
||||
E_lfanew int32 // File address of new exe header
|
||||
}
|
||||
|
||||
// File header format
|
||||
type IMAGE_FILE_HEADER struct {
|
||||
Machine uint16
|
||||
NumberOfSections uint16
|
||||
TimeDateStamp uint32
|
||||
PointerToSymbolTable uint32
|
||||
NumberOfSymbols uint32
|
||||
SizeOfOptionalHeader uint16
|
||||
Characteristics uint16
|
||||
}
|
||||
|
||||
const (
|
||||
IMAGE_SIZEOF_FILE_HEADER = 20
|
||||
|
||||
IMAGE_FILE_RELOCS_STRIPPED = 0x0001 // Relocation info stripped from file.
|
||||
IMAGE_FILE_EXECUTABLE_IMAGE = 0x0002 // File is executable (i.e. no unresolved external references).
|
||||
IMAGE_FILE_LINE_NUMS_STRIPPED = 0x0004 // Line nunbers stripped from file.
|
||||
IMAGE_FILE_LOCAL_SYMS_STRIPPED = 0x0008 // Local symbols stripped from file.
|
||||
IMAGE_FILE_AGGRESIVE_WS_TRIM = 0x0010 // Aggressively trim working set
|
||||
IMAGE_FILE_LARGE_ADDRESS_AWARE = 0x0020 // App can handle >2gb addresses
|
||||
IMAGE_FILE_BYTES_REVERSED_LO = 0x0080 // Bytes of machine word are reversed.
|
||||
IMAGE_FILE_32BIT_MACHINE = 0x0100 // 32 bit word machine.
|
||||
IMAGE_FILE_DEBUG_STRIPPED = 0x0200 // Debugging info stripped from file in .DBG file
|
||||
IMAGE_FILE_REMOVABLE_RUN_FROM_SWAP = 0x0400 // If Image is on removable media, copy and run from the swap file.
|
||||
IMAGE_FILE_NET_RUN_FROM_SWAP = 0x0800 // If Image is on Net, copy and run from the swap file.
|
||||
IMAGE_FILE_SYSTEM = 0x1000 // System File.
|
||||
IMAGE_FILE_DLL = 0x2000 // File is a DLL.
|
||||
IMAGE_FILE_UP_SYSTEM_ONLY = 0x4000 // File should only be run on a UP machine
|
||||
IMAGE_FILE_BYTES_REVERSED_HI = 0x8000 // Bytes of machine word are reversed.
|
||||
|
||||
IMAGE_FILE_MACHINE_UNKNOWN = 0
|
||||
IMAGE_FILE_MACHINE_TARGET_HOST = 0x0001 // Useful for indicating we want to interact with the host and not a WoW guest.
|
||||
IMAGE_FILE_MACHINE_I386 = 0x014c // Intel 386.
|
||||
IMAGE_FILE_MACHINE_R3000 = 0x0162 // MIPS little-endian, 0x160 big-endian
|
||||
IMAGE_FILE_MACHINE_R4000 = 0x0166 // MIPS little-endian
|
||||
IMAGE_FILE_MACHINE_R10000 = 0x0168 // MIPS little-endian
|
||||
IMAGE_FILE_MACHINE_WCEMIPSV2 = 0x0169 // MIPS little-endian WCE v2
|
||||
IMAGE_FILE_MACHINE_ALPHA = 0x0184 // Alpha_AXP
|
||||
IMAGE_FILE_MACHINE_SH3 = 0x01a2 // SH3 little-endian
|
||||
IMAGE_FILE_MACHINE_SH3DSP = 0x01a3
|
||||
IMAGE_FILE_MACHINE_SH3E = 0x01a4 // SH3E little-endian
|
||||
IMAGE_FILE_MACHINE_SH4 = 0x01a6 // SH4 little-endian
|
||||
IMAGE_FILE_MACHINE_SH5 = 0x01a8 // SH5
|
||||
IMAGE_FILE_MACHINE_ARM = 0x01c0 // ARM Little-Endian
|
||||
IMAGE_FILE_MACHINE_THUMB = 0x01c2 // ARM Thumb/Thumb-2 Little-Endian
|
||||
IMAGE_FILE_MACHINE_ARMNT = 0x01c4 // ARM Thumb-2 Little-Endian
|
||||
IMAGE_FILE_MACHINE_AM33 = 0x01d3
|
||||
IMAGE_FILE_MACHINE_POWERPC = 0x01F0 // IBM PowerPC Little-Endian
|
||||
IMAGE_FILE_MACHINE_POWERPCFP = 0x01f1
|
||||
IMAGE_FILE_MACHINE_IA64 = 0x0200 // Intel 64
|
||||
IMAGE_FILE_MACHINE_MIPS16 = 0x0266 // MIPS
|
||||
IMAGE_FILE_MACHINE_ALPHA64 = 0x0284 // ALPHA64
|
||||
IMAGE_FILE_MACHINE_MIPSFPU = 0x0366 // MIPS
|
||||
IMAGE_FILE_MACHINE_MIPSFPU16 = 0x0466 // MIPS
|
||||
IMAGE_FILE_MACHINE_AXP64 = IMAGE_FILE_MACHINE_ALPHA64
|
||||
IMAGE_FILE_MACHINE_TRICORE = 0x0520 // Infineon
|
||||
IMAGE_FILE_MACHINE_CEF = 0x0CEF
|
||||
IMAGE_FILE_MACHINE_EBC = 0x0EBC // EFI Byte Code
|
||||
IMAGE_FILE_MACHINE_AMD64 = 0x8664 // AMD64 (K8)
|
||||
IMAGE_FILE_MACHINE_M32R = 0x9041 // M32R little-endian
|
||||
IMAGE_FILE_MACHINE_ARM64 = 0xAA64 // ARM64 Little-Endian
|
||||
IMAGE_FILE_MACHINE_CEE = 0xC0EE
|
||||
)
|
||||
|
||||
// Directory format
|
||||
type IMAGE_DATA_DIRECTORY struct {
|
||||
VirtualAddress uint32
|
||||
Size uint32
|
||||
}
|
||||
|
||||
const IMAGE_NUMBEROF_DIRECTORY_ENTRIES = 16
|
||||
|
||||
type IMAGE_NT_HEADERS struct {
|
||||
Signature uint32
|
||||
FileHeader IMAGE_FILE_HEADER
|
||||
OptionalHeader IMAGE_OPTIONAL_HEADER
|
||||
}
|
||||
|
||||
func (ntheader *IMAGE_NT_HEADERS) Sections() []IMAGE_SECTION_HEADER {
|
||||
return (*[0xffff]IMAGE_SECTION_HEADER)(unsafe.Pointer(
|
||||
(uintptr)(unsafe.Pointer(ntheader)) +
|
||||
unsafe.Offsetof(ntheader.OptionalHeader) +
|
||||
uintptr(ntheader.FileHeader.SizeOfOptionalHeader)))[:ntheader.FileHeader.NumberOfSections]
|
||||
}
|
||||
|
||||
const (
|
||||
IMAGE_DIRECTORY_ENTRY_EXPORT = 0 // Export Directory
|
||||
IMAGE_DIRECTORY_ENTRY_IMPORT = 1 // Import Directory
|
||||
IMAGE_DIRECTORY_ENTRY_RESOURCE = 2 // Resource Directory
|
||||
IMAGE_DIRECTORY_ENTRY_EXCEPTION = 3 // Exception Directory
|
||||
IMAGE_DIRECTORY_ENTRY_SECURITY = 4 // Security Directory
|
||||
IMAGE_DIRECTORY_ENTRY_BASERELOC = 5 // Base Relocation Table
|
||||
IMAGE_DIRECTORY_ENTRY_DEBUG = 6 // Debug Directory
|
||||
IMAGE_DIRECTORY_ENTRY_COPYRIGHT = 7 // (X86 usage)
|
||||
IMAGE_DIRECTORY_ENTRY_ARCHITECTURE = 7 // Architecture Specific Data
|
||||
IMAGE_DIRECTORY_ENTRY_GLOBALPTR = 8 // RVA of GP
|
||||
IMAGE_DIRECTORY_ENTRY_TLS = 9 // TLS Directory
|
||||
IMAGE_DIRECTORY_ENTRY_LOAD_CONFIG = 10 // Load Configuration Directory
|
||||
IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT = 11 // Bound Import Directory in headers
|
||||
IMAGE_DIRECTORY_ENTRY_IAT = 12 // Import Address Table
|
||||
IMAGE_DIRECTORY_ENTRY_DELAY_IMPORT = 13 // Delay Load Import Descriptors
|
||||
IMAGE_DIRECTORY_ENTRY_COM_DESCRIPTOR = 14 // COM Runtime descriptor
|
||||
)
|
||||
|
||||
const IMAGE_SIZEOF_SHORT_NAME = 8
|
||||
|
||||
// Section header format
|
||||
type IMAGE_SECTION_HEADER struct {
|
||||
Name [IMAGE_SIZEOF_SHORT_NAME]byte
|
||||
physicalAddressOrVirtualSize uint32
|
||||
VirtualAddress uint32
|
||||
SizeOfRawData uint32
|
||||
PointerToRawData uint32
|
||||
PointerToRelocations uint32
|
||||
PointerToLinenumbers uint32
|
||||
NumberOfRelocations uint16
|
||||
NumberOfLinenumbers uint16
|
||||
Characteristics uint32
|
||||
}
|
||||
|
||||
func (ishdr *IMAGE_SECTION_HEADER) PhysicalAddress() uint32 {
|
||||
return ishdr.physicalAddressOrVirtualSize
|
||||
}
|
||||
|
||||
func (ishdr *IMAGE_SECTION_HEADER) SetPhysicalAddress(addr uint32) {
|
||||
ishdr.physicalAddressOrVirtualSize = addr
|
||||
}
|
||||
|
||||
func (ishdr *IMAGE_SECTION_HEADER) VirtualSize() uint32 {
|
||||
return ishdr.physicalAddressOrVirtualSize
|
||||
}
|
||||
|
||||
func (ishdr *IMAGE_SECTION_HEADER) SetVirtualSize(addr uint32) {
|
||||
ishdr.physicalAddressOrVirtualSize = addr
|
||||
}
|
||||
|
||||
const (
|
||||
// Section characteristics.
|
||||
IMAGE_SCN_TYPE_REG = 0x00000000 // Reserved.
|
||||
IMAGE_SCN_TYPE_DSECT = 0x00000001 // Reserved.
|
||||
IMAGE_SCN_TYPE_NOLOAD = 0x00000002 // Reserved.
|
||||
IMAGE_SCN_TYPE_GROUP = 0x00000004 // Reserved.
|
||||
IMAGE_SCN_TYPE_NO_PAD = 0x00000008 // Reserved.
|
||||
IMAGE_SCN_TYPE_COPY = 0x00000010 // Reserved.
|
||||
|
||||
IMAGE_SCN_CNT_CODE = 0x00000020 // Section contains code.
|
||||
IMAGE_SCN_CNT_INITIALIZED_DATA = 0x00000040 // Section contains initialized data.
|
||||
IMAGE_SCN_CNT_UNINITIALIZED_DATA = 0x00000080 // Section contains uninitialized data.
|
||||
|
||||
IMAGE_SCN_LNK_OTHER = 0x00000100 // Reserved.
|
||||
IMAGE_SCN_LNK_INFO = 0x00000200 // Section contains comments or some other type of information.
|
||||
IMAGE_SCN_TYPE_OVER = 0x00000400 // Reserved.
|
||||
IMAGE_SCN_LNK_REMOVE = 0x00000800 // Section contents will not become part of image.
|
||||
IMAGE_SCN_LNK_COMDAT = 0x00001000 // Section contents comdat.
|
||||
IMAGE_SCN_MEM_PROTECTED = 0x00004000 // Obsolete.
|
||||
IMAGE_SCN_NO_DEFER_SPEC_EXC = 0x00004000 // Reset speculative exceptions handling bits in the TLB entries for this section.
|
||||
IMAGE_SCN_GPREL = 0x00008000 // Section content can be accessed relative to GP
|
||||
IMAGE_SCN_MEM_FARDATA = 0x00008000
|
||||
IMAGE_SCN_MEM_SYSHEAP = 0x00010000 // Obsolete.
|
||||
IMAGE_SCN_MEM_PURGEABLE = 0x00020000
|
||||
IMAGE_SCN_MEM_16BIT = 0x00020000
|
||||
IMAGE_SCN_MEM_LOCKED = 0x00040000
|
||||
IMAGE_SCN_MEM_PRELOAD = 0x00080000
|
||||
|
||||
IMAGE_SCN_ALIGN_1BYTES = 0x00100000 //
|
||||
IMAGE_SCN_ALIGN_2BYTES = 0x00200000 //
|
||||
IMAGE_SCN_ALIGN_4BYTES = 0x00300000 //
|
||||
IMAGE_SCN_ALIGN_8BYTES = 0x00400000 //
|
||||
IMAGE_SCN_ALIGN_16BYTES = 0x00500000 // Default alignment if no others are specified.
|
||||
IMAGE_SCN_ALIGN_32BYTES = 0x00600000 //
|
||||
IMAGE_SCN_ALIGN_64BYTES = 0x00700000 //
|
||||
IMAGE_SCN_ALIGN_128BYTES = 0x00800000 //
|
||||
IMAGE_SCN_ALIGN_256BYTES = 0x00900000 //
|
||||
IMAGE_SCN_ALIGN_512BYTES = 0x00A00000 //
|
||||
IMAGE_SCN_ALIGN_1024BYTES = 0x00B00000 //
|
||||
IMAGE_SCN_ALIGN_2048BYTES = 0x00C00000 //
|
||||
IMAGE_SCN_ALIGN_4096BYTES = 0x00D00000 //
|
||||
IMAGE_SCN_ALIGN_8192BYTES = 0x00E00000 //
|
||||
IMAGE_SCN_ALIGN_MASK = 0x00F00000
|
||||
|
||||
IMAGE_SCN_LNK_NRELOC_OVFL = 0x01000000 // Section contains extended relocations.
|
||||
IMAGE_SCN_MEM_DISCARDABLE = 0x02000000 // Section can be discarded.
|
||||
IMAGE_SCN_MEM_NOT_CACHED = 0x04000000 // Section is not cachable.
|
||||
IMAGE_SCN_MEM_NOT_PAGED = 0x08000000 // Section is not pageable.
|
||||
IMAGE_SCN_MEM_SHARED = 0x10000000 // Section is shareable.
|
||||
IMAGE_SCN_MEM_EXECUTE = 0x20000000 // Section is executable.
|
||||
IMAGE_SCN_MEM_READ = 0x40000000 // Section is readable.
|
||||
IMAGE_SCN_MEM_WRITE = 0x80000000 // Section is writeable.
|
||||
|
||||
// TLS Characteristic Flags
|
||||
IMAGE_SCN_SCALE_INDEX = 0x00000001 // Tls index is scaled.
|
||||
)
|
||||
|
||||
// Based relocation format
|
||||
type IMAGE_BASE_RELOCATION struct {
|
||||
VirtualAddress uint32
|
||||
SizeOfBlock uint32
|
||||
}
|
||||
|
||||
const (
|
||||
IMAGE_REL_BASED_ABSOLUTE = 0
|
||||
IMAGE_REL_BASED_HIGH = 1
|
||||
IMAGE_REL_BASED_LOW = 2
|
||||
IMAGE_REL_BASED_HIGHLOW = 3
|
||||
IMAGE_REL_BASED_HIGHADJ = 4
|
||||
IMAGE_REL_BASED_MACHINE_SPECIFIC_5 = 5
|
||||
IMAGE_REL_BASED_RESERVED = 6
|
||||
IMAGE_REL_BASED_MACHINE_SPECIFIC_7 = 7
|
||||
IMAGE_REL_BASED_MACHINE_SPECIFIC_8 = 8
|
||||
IMAGE_REL_BASED_MACHINE_SPECIFIC_9 = 9
|
||||
IMAGE_REL_BASED_DIR64 = 10
|
||||
|
||||
IMAGE_REL_BASED_IA64_IMM64 = 9
|
||||
|
||||
IMAGE_REL_BASED_MIPS_JMPADDR = 5
|
||||
IMAGE_REL_BASED_MIPS_JMPADDR16 = 9
|
||||
|
||||
IMAGE_REL_BASED_ARM_MOV32 = 5
|
||||
IMAGE_REL_BASED_THUMB_MOV32 = 7
|
||||
)
|
||||
|
||||
// Export Format
|
||||
type IMAGE_EXPORT_DIRECTORY struct {
|
||||
Characteristics uint32
|
||||
TimeDateStamp uint32
|
||||
MajorVersion uint16
|
||||
MinorVersion uint16
|
||||
Name uint32
|
||||
Base uint32
|
||||
NumberOfFunctions uint32
|
||||
NumberOfNames uint32
|
||||
AddressOfFunctions uint32 // RVA from base of image
|
||||
AddressOfNames uint32 // RVA from base of image
|
||||
AddressOfNameOrdinals uint32 // RVA from base of image
|
||||
}
|
||||
|
||||
type IMAGE_IMPORT_BY_NAME struct {
|
||||
Hint uint16
|
||||
Name [1]byte
|
||||
}
|
||||
|
||||
func IMAGE_ORDINAL(ordinal uintptr) uintptr {
|
||||
return ordinal & 0xffff
|
||||
}
|
||||
|
||||
func IMAGE_SNAP_BY_ORDINAL(ordinal uintptr) bool {
|
||||
return (ordinal & IMAGE_ORDINAL_FLAG) != 0
|
||||
}
|
||||
|
||||
// Thread Local Storage
|
||||
type IMAGE_TLS_DIRECTORY struct {
|
||||
StartAddressOfRawData uintptr
|
||||
EndAddressOfRawData uintptr
|
||||
AddressOfIndex uintptr // PDWORD
|
||||
AddressOfCallbacks uintptr // PIMAGE_TLS_CALLBACK *;
|
||||
SizeOfZeroFill uint32
|
||||
Characteristics uint32
|
||||
}
|
||||
|
||||
type IMAGE_IMPORT_DESCRIPTOR struct {
|
||||
characteristicsOrOriginalFirstThunk uint32 // 0 for terminating null import descriptor
|
||||
// RVA to original unbound IAT (PIMAGE_THUNK_DATA)
|
||||
TimeDateStamp uint32 // 0 if not bound,
|
||||
// -1 if bound, and real date\time stamp
|
||||
// in IMAGE_DIRECTORY_ENTRY_BOUND_IMPORT (new BIND)
|
||||
// O.W. date/time stamp of DLL bound to (Old BIND)
|
||||
ForwarderChain uint32 // -1 if no forwarders
|
||||
Name uint32
|
||||
FirstThunk uint32 // RVA to IAT (if bound this IAT has actual addresses)
|
||||
}
|
||||
|
||||
func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) Characteristics() uint32 {
|
||||
return imgimpdesc.characteristicsOrOriginalFirstThunk
|
||||
}
|
||||
|
||||
func (imgimpdesc *IMAGE_IMPORT_DESCRIPTOR) OriginalFirstThunk() uint32 {
|
||||
return imgimpdesc.characteristicsOrOriginalFirstThunk
|
||||
}
|
||||
|
||||
const (
|
||||
DLL_PROCESS_ATTACH = 1
|
||||
DLL_THREAD_ATTACH = 2
|
||||
DLL_THREAD_DETACH = 3
|
||||
DLL_PROCESS_DETACH = 0
|
||||
)
|
||||
|
||||
//sys loadLibraryA(libFileName *byte) (module windows.Handle, err error) = kernel32.LoadLibraryA
|
||||
//sys getProcAddress(module windows.Handle, procName *byte) (addr uintptr, err error) = kernel32.GetProcAddress
|
||||
//sys isBadReadPtr(addr uintptr, ucb uintptr) (ret bool) = kernel32.IsBadReadPtr
|
||||
|
||||
type SYSTEM_INFO struct {
|
||||
ProcessorArchitecture uint16
|
||||
Reserved uint16
|
||||
PageSize uint32
|
||||
MinimumApplicationAddress uintptr
|
||||
MaximumApplicationAddress uintptr
|
||||
ActiveProcessorMask uintptr
|
||||
NumberOfProcessors uint32
|
||||
ProcessorType uint32
|
||||
AllocationGranularity uint32
|
||||
ProcessorLevel uint16
|
||||
ProcessorRevision uint16
|
||||
}
|
||||
45
tun/wintun/memmod/syscall_windows_32.go
Normal file
45
tun/wintun/memmod/syscall_windows_32.go
Normal file
@@ -0,0 +1,45 @@
|
||||
// +build 386 arm
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package memmod
|
||||
|
||||
// Optional header format
|
||||
type IMAGE_OPTIONAL_HEADER struct {
|
||||
Magic uint16
|
||||
MajorLinkerVersion uint8
|
||||
MinorLinkerVersion uint8
|
||||
SizeOfCode uint32
|
||||
SizeOfInitializedData uint32
|
||||
SizeOfUninitializedData uint32
|
||||
AddressOfEntryPoint uint32
|
||||
BaseOfCode uint32
|
||||
BaseOfData uint32
|
||||
ImageBase uintptr
|
||||
SectionAlignment uint32
|
||||
FileAlignment uint32
|
||||
MajorOperatingSystemVersion uint16
|
||||
MinorOperatingSystemVersion uint16
|
||||
MajorImageVersion uint16
|
||||
MinorImageVersion uint16
|
||||
MajorSubsystemVersion uint16
|
||||
MinorSubsystemVersion uint16
|
||||
Win32VersionValue uint32
|
||||
SizeOfImage uint32
|
||||
SizeOfHeaders uint32
|
||||
CheckSum uint32
|
||||
Subsystem uint16
|
||||
DllCharacteristics uint16
|
||||
SizeOfStackReserve uintptr
|
||||
SizeOfStackCommit uintptr
|
||||
SizeOfHeapReserve uintptr
|
||||
SizeOfHeapCommit uintptr
|
||||
LoaderFlags uint32
|
||||
NumberOfRvaAndSizes uint32
|
||||
DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
|
||||
}
|
||||
|
||||
const IMAGE_ORDINAL_FLAG uintptr = 0x80000000
|
||||
44
tun/wintun/memmod/syscall_windows_64.go
Normal file
44
tun/wintun/memmod/syscall_windows_64.go
Normal file
@@ -0,0 +1,44 @@
|
||||
// +build amd64 arm64
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2020 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package memmod
|
||||
|
||||
// Optional header format
|
||||
type IMAGE_OPTIONAL_HEADER struct {
|
||||
Magic uint16
|
||||
MajorLinkerVersion uint8
|
||||
MinorLinkerVersion uint8
|
||||
SizeOfCode uint32
|
||||
SizeOfInitializedData uint32
|
||||
SizeOfUninitializedData uint32
|
||||
AddressOfEntryPoint uint32
|
||||
BaseOfCode uint32
|
||||
ImageBase uintptr
|
||||
SectionAlignment uint32
|
||||
FileAlignment uint32
|
||||
MajorOperatingSystemVersion uint16
|
||||
MinorOperatingSystemVersion uint16
|
||||
MajorImageVersion uint16
|
||||
MinorImageVersion uint16
|
||||
MajorSubsystemVersion uint16
|
||||
MinorSubsystemVersion uint16
|
||||
Win32VersionValue uint32
|
||||
SizeOfImage uint32
|
||||
SizeOfHeaders uint32
|
||||
CheckSum uint32
|
||||
Subsystem uint16
|
||||
DllCharacteristics uint16
|
||||
SizeOfStackReserve uintptr
|
||||
SizeOfStackCommit uintptr
|
||||
SizeOfHeapReserve uintptr
|
||||
SizeOfHeapCommit uintptr
|
||||
LoaderFlags uint32
|
||||
NumberOfRvaAndSizes uint32
|
||||
DataDirectory [IMAGE_NUMBEROF_DIRECTORY_ENTRIES]IMAGE_DATA_DIRECTORY
|
||||
}
|
||||
|
||||
const IMAGE_ORDINAL_FLAG uintptr = 0x8000000000000000
|
||||
70
tun/wintun/memmod/zsyscall_windows.go
Normal file
70
tun/wintun/memmod/zsyscall_windows.go
Normal file
@@ -0,0 +1,70 @@
|
||||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package memmod
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
errERROR_EINVAL error = syscall.EINVAL
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return errERROR_EINVAL
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
|
||||
procGetProcAddress = modkernel32.NewProc("GetProcAddress")
|
||||
procIsBadReadPtr = modkernel32.NewProc("IsBadReadPtr")
|
||||
procLoadLibraryA = modkernel32.NewProc("LoadLibraryA")
|
||||
)
|
||||
|
||||
func getProcAddress(module windows.Handle, procName *byte) (addr uintptr, err error) {
|
||||
r0, _, e1 := syscall.Syscall(procGetProcAddress.Addr(), 2, uintptr(module), uintptr(unsafe.Pointer(procName)), 0)
|
||||
addr = uintptr(r0)
|
||||
if addr == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func isBadReadPtr(addr uintptr, ucb uintptr) (ret bool) {
|
||||
r0, _, _ := syscall.Syscall(procIsBadReadPtr.Addr(), 2, uintptr(addr), uintptr(ucb), 0)
|
||||
ret = r0 != 0
|
||||
return
|
||||
}
|
||||
|
||||
func loadLibraryA(libFileName *byte) (module windows.Handle, err error) {
|
||||
r0, _, e1 := syscall.Syscall(procLoadLibraryA.Addr(), 1, uintptr(unsafe.Pointer(libFileName)), 0, 0)
|
||||
module = windows.Handle(r0)
|
||||
if module == 0 {
|
||||
err = errnoErr(e1)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1,98 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package wintun
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/crypto/blake2s"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/text/unicode/norm"
|
||||
|
||||
"golang.zx2c4.com/wireguard/tun/wintun/namespaceapi"
|
||||
)
|
||||
|
||||
var (
|
||||
wintunObjectSecurityAttributes *windows.SecurityAttributes
|
||||
hasInitializedNamespace bool
|
||||
initializingNamespace sync.Mutex
|
||||
)
|
||||
|
||||
func initializeNamespace() error {
|
||||
initializingNamespace.Lock()
|
||||
defer initializingNamespace.Unlock()
|
||||
if hasInitializedNamespace {
|
||||
return nil
|
||||
}
|
||||
sd, err := windows.SecurityDescriptorFromString("O:SYD:P(A;;GA;;;SY)")
|
||||
if err != nil {
|
||||
return fmt.Errorf("SddlToSecurityDescriptor failed: %v", err)
|
||||
}
|
||||
wintunObjectSecurityAttributes = &windows.SecurityAttributes{
|
||||
Length: uint32(unsafe.Sizeof(windows.SecurityAttributes{})),
|
||||
SecurityDescriptor: sd,
|
||||
}
|
||||
sid, err := windows.CreateWellKnownSid(windows.WinLocalSystemSid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("CreateWellKnownSid(LOCAL_SYSTEM) failed: %v", err)
|
||||
}
|
||||
|
||||
boundary, err := namespaceapi.CreateBoundaryDescriptor("Wintun")
|
||||
if err != nil {
|
||||
return fmt.Errorf("CreateBoundaryDescriptor failed: %v", err)
|
||||
}
|
||||
err = boundary.AddSid(sid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("AddSIDToBoundaryDescriptor failed: %v", err)
|
||||
}
|
||||
for {
|
||||
_, err = namespaceapi.CreatePrivateNamespace(wintunObjectSecurityAttributes, boundary, "Wintun")
|
||||
if err == windows.ERROR_ALREADY_EXISTS {
|
||||
_, err = namespaceapi.OpenPrivateNamespace(boundary, "Wintun")
|
||||
if err == windows.ERROR_PATH_NOT_FOUND {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("Create/OpenPrivateNamespace failed: %v", err)
|
||||
}
|
||||
break
|
||||
}
|
||||
hasInitializedNamespace = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pool Pool) takeNameMutex() (windows.Handle, error) {
|
||||
err := initializeNamespace()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
const mutexLabel = "WireGuard Adapter Name Mutex Stable Suffix v1 jason@zx2c4.com"
|
||||
b2, _ := blake2s.New256(nil)
|
||||
b2.Write([]byte(mutexLabel))
|
||||
b2.Write(norm.NFC.Bytes([]byte(string(pool))))
|
||||
mutexName := `Wintun\Wintun-Name-Mutex-` + hex.EncodeToString(b2.Sum(nil))
|
||||
mutex, err := windows.CreateMutex(wintunObjectSecurityAttributes, false, windows.StringToUTF16Ptr(mutexName))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("Error creating name mutex: %v", err)
|
||||
return 0, err
|
||||
}
|
||||
event, err := windows.WaitForSingleObject(mutex, windows.INFINITE)
|
||||
if err != nil {
|
||||
windows.CloseHandle(mutex)
|
||||
return 0, fmt.Errorf("Error waiting on name mutex: %v", err)
|
||||
}
|
||||
if event != windows.WAIT_OBJECT_0 && event != windows.WAIT_ABANDONED {
|
||||
windows.CloseHandle(mutex)
|
||||
return 0, errors.New("Error with event trigger of name mutex")
|
||||
}
|
||||
return mutex, nil
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package namespaceapi
|
||||
|
||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go namespaceapi_windows.go
|
||||
@@ -1,83 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package namespaceapi
|
||||
|
||||
import "golang.org/x/sys/windows"
|
||||
|
||||
//sys createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) = kernel32.CreateBoundaryDescriptorW
|
||||
//sys deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) = kernel32.DeleteBoundaryDescriptor
|
||||
//sys addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) = kernel32.AddSIDToBoundaryDescriptor
|
||||
//sys createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.CreatePrivateNamespaceW
|
||||
//sys openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) = kernel32.OpenPrivateNamespaceW
|
||||
//sys closePrivateNamespace(handle windows.Handle, flags uint32) (err error) = kernel32.ClosePrivateNamespace
|
||||
|
||||
// BoundaryDescriptor represents a boundary that defines how the objects in the namespace are to be isolated.
|
||||
type BoundaryDescriptor windows.Handle
|
||||
|
||||
// CreateBoundaryDescriptor creates a boundary descriptor.
|
||||
func CreateBoundaryDescriptor(name string) (BoundaryDescriptor, error) {
|
||||
name16, err := windows.UTF16PtrFromString(name)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
handle, err := createBoundaryDescriptor(name16, 0)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return BoundaryDescriptor(handle), nil
|
||||
}
|
||||
|
||||
// Delete deletes the specified boundary descriptor.
|
||||
func (bd BoundaryDescriptor) Delete() {
|
||||
deleteBoundaryDescriptor(windows.Handle(bd))
|
||||
}
|
||||
|
||||
// AddSid adds a security identifier (SID) to the specified boundary descriptor.
|
||||
func (bd *BoundaryDescriptor) AddSid(requiredSid *windows.SID) error {
|
||||
return addSIDToBoundaryDescriptor((*windows.Handle)(bd), requiredSid)
|
||||
}
|
||||
|
||||
// PrivateNamespace represents a private namespace.
|
||||
type PrivateNamespace windows.Handle
|
||||
|
||||
// CreatePrivateNamespace creates a private namespace.
|
||||
func CreatePrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) {
|
||||
aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
handle, err := createPrivateNamespace(privateNamespaceAttributes, windows.Handle(boundaryDescriptor), aliasPrefix16)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return PrivateNamespace(handle), nil
|
||||
}
|
||||
|
||||
// OpenPrivateNamespace opens a private namespace.
|
||||
func OpenPrivateNamespace(boundaryDescriptor BoundaryDescriptor, aliasPrefix string) (PrivateNamespace, error) {
|
||||
aliasPrefix16, err := windows.UTF16PtrFromString(aliasPrefix)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
handle, err := openPrivateNamespace(windows.Handle(boundaryDescriptor), aliasPrefix16)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return PrivateNamespace(handle), nil
|
||||
}
|
||||
|
||||
// ClosePrivateNamespaceFlags describes flags that are used by PrivateNamespace's Close() method.
|
||||
type ClosePrivateNamespaceFlags uint32
|
||||
|
||||
const (
|
||||
// PrivateNamespaceFlagDestroy makes the close to destroy the namespace.
|
||||
PrivateNamespaceFlagDestroy = ClosePrivateNamespaceFlags(0x1)
|
||||
)
|
||||
|
||||
// Close closes an open namespace handle.
|
||||
func (pns PrivateNamespace) Close(flags ClosePrivateNamespaceFlags) error {
|
||||
return closePrivateNamespace(windows.Handle(pns), uint32(flags))
|
||||
}
|
||||
@@ -1,116 +0,0 @@
|
||||
// Code generated by 'go generate'; DO NOT EDIT.
|
||||
|
||||
package namespaceapi
|
||||
|
||||
import (
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var _ unsafe.Pointer
|
||||
|
||||
// Do the interface allocations only once for common
|
||||
// Errno values.
|
||||
const (
|
||||
errnoERROR_IO_PENDING = 997
|
||||
)
|
||||
|
||||
var (
|
||||
errERROR_IO_PENDING error = syscall.Errno(errnoERROR_IO_PENDING)
|
||||
)
|
||||
|
||||
// errnoErr returns common boxed Errno values, to prevent
|
||||
// allocations at runtime.
|
||||
func errnoErr(e syscall.Errno) error {
|
||||
switch e {
|
||||
case 0:
|
||||
return nil
|
||||
case errnoERROR_IO_PENDING:
|
||||
return errERROR_IO_PENDING
|
||||
}
|
||||
// TODO: add more here, after collecting data on the common
|
||||
// error values see on Windows. (perhaps when running
|
||||
// all.bat?)
|
||||
return e
|
||||
}
|
||||
|
||||
var (
|
||||
modkernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
|
||||
procCreateBoundaryDescriptorW = modkernel32.NewProc("CreateBoundaryDescriptorW")
|
||||
procDeleteBoundaryDescriptor = modkernel32.NewProc("DeleteBoundaryDescriptor")
|
||||
procAddSIDToBoundaryDescriptor = modkernel32.NewProc("AddSIDToBoundaryDescriptor")
|
||||
procCreatePrivateNamespaceW = modkernel32.NewProc("CreatePrivateNamespaceW")
|
||||
procOpenPrivateNamespaceW = modkernel32.NewProc("OpenPrivateNamespaceW")
|
||||
procClosePrivateNamespace = modkernel32.NewProc("ClosePrivateNamespace")
|
||||
)
|
||||
|
||||
func createBoundaryDescriptor(name *uint16, flags uint32) (handle windows.Handle, err error) {
|
||||
r0, _, e1 := syscall.Syscall(procCreateBoundaryDescriptorW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(flags), 0)
|
||||
handle = windows.Handle(r0)
|
||||
if handle == 0 {
|
||||
if e1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func deleteBoundaryDescriptor(boundaryDescriptor windows.Handle) {
|
||||
syscall.Syscall(procDeleteBoundaryDescriptor.Addr(), 1, uintptr(boundaryDescriptor), 0, 0)
|
||||
return
|
||||
}
|
||||
|
||||
func addSIDToBoundaryDescriptor(boundaryDescriptor *windows.Handle, requiredSid *windows.SID) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procAddSIDToBoundaryDescriptor.Addr(), 2, uintptr(unsafe.Pointer(boundaryDescriptor)), uintptr(unsafe.Pointer(requiredSid)), 0)
|
||||
if r1 == 0 {
|
||||
if e1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func createPrivateNamespace(privateNamespaceAttributes *windows.SecurityAttributes, boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) {
|
||||
r0, _, e1 := syscall.Syscall(procCreatePrivateNamespaceW.Addr(), 3, uintptr(unsafe.Pointer(privateNamespaceAttributes)), uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)))
|
||||
handle = windows.Handle(r0)
|
||||
if handle == 0 {
|
||||
if e1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func openPrivateNamespace(boundaryDescriptor windows.Handle, aliasPrefix *uint16) (handle windows.Handle, err error) {
|
||||
r0, _, e1 := syscall.Syscall(procOpenPrivateNamespaceW.Addr(), 2, uintptr(boundaryDescriptor), uintptr(unsafe.Pointer(aliasPrefix)), 0)
|
||||
handle = windows.Handle(r0)
|
||||
if handle == 0 {
|
||||
if e1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func closePrivateNamespace(handle windows.Handle, flags uint32) (err error) {
|
||||
r1, _, e1 := syscall.Syscall(procClosePrivateNamespace.Addr(), 2, uintptr(handle), uintptr(flags), 0)
|
||||
if r1 == 0 {
|
||||
if e1 != 0 {
|
||||
err = errnoErr(e1)
|
||||
} else {
|
||||
err = syscall.EINVAL
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -1,8 +0,0 @@
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2019 WireGuard LLC. All Rights Reserved.
|
||||
*/
|
||||
|
||||
package nci
|
||||
|
||||
//go:generate go run golang.org/x/sys/windows/mkwinsyscall -output zsyscall_windows.go nci_windows.go
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user