global: begin modularization

This commit is contained in:
Jason A. Donenfeld
2019-03-03 04:04:41 +01:00
parent d435be35ca
commit 69f0fe67b6
44 changed files with 118 additions and 109 deletions

251
device/allowedips.go Normal file
View File

@@ -0,0 +1,251 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
"math/bits"
"net"
"sync"
"unsafe"
)
type trieEntry struct {
cidr uint
child [2]*trieEntry
bits net.IP
peer *Peer
// index of "branching" bit
bit_at_byte uint
bit_at_shift uint
}
func isLittleEndian() bool {
one := uint32(1)
return *(*byte)(unsafe.Pointer(&one)) != 0
}
func swapU32(i uint32) uint32 {
if !isLittleEndian() {
return i
}
return bits.ReverseBytes32(i)
}
func swapU64(i uint64) uint64 {
if !isLittleEndian() {
return i
}
return bits.ReverseBytes64(i)
}
func commonBits(ip1 net.IP, ip2 net.IP) uint {
size := len(ip1)
if size == net.IPv4len {
a := (*uint32)(unsafe.Pointer(&ip1[0]))
b := (*uint32)(unsafe.Pointer(&ip2[0]))
x := *a ^ *b
return uint(bits.LeadingZeros32(swapU32(x)))
} else if size == net.IPv6len {
a := (*uint64)(unsafe.Pointer(&ip1[0]))
b := (*uint64)(unsafe.Pointer(&ip2[0]))
x := *a ^ *b
if x != 0 {
return uint(bits.LeadingZeros64(swapU64(x)))
}
a = (*uint64)(unsafe.Pointer(&ip1[8]))
b = (*uint64)(unsafe.Pointer(&ip2[8]))
x = *a ^ *b
return 64 + uint(bits.LeadingZeros64(swapU64(x)))
} else {
panic("Wrong size bit string")
}
}
func (node *trieEntry) removeByPeer(p *Peer) *trieEntry {
if node == nil {
return node
}
// walk recursively
node.child[0] = node.child[0].removeByPeer(p)
node.child[1] = node.child[1].removeByPeer(p)
if node.peer != p {
return node
}
// remove peer & merge
node.peer = nil
if node.child[0] == nil {
return node.child[1]
}
return node.child[0]
}
func (node *trieEntry) choose(ip net.IP) byte {
return (ip[node.bit_at_byte] >> node.bit_at_shift) & 1
}
func (node *trieEntry) insert(ip net.IP, cidr uint, peer *Peer) *trieEntry {
// at leaf
if node == nil {
return &trieEntry{
bits: ip,
peer: peer,
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
}
}
// traverse deeper
common := commonBits(node.bits, ip)
if node.cidr <= cidr && common >= node.cidr {
if node.cidr == cidr {
node.peer = peer
return node
}
bit := node.choose(ip)
node.child[bit] = node.child[bit].insert(ip, cidr, peer)
return node
}
// split node
newNode := &trieEntry{
bits: ip,
peer: peer,
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
}
cidr = min(cidr, common)
// check for shorter prefix
if newNode.cidr == cidr {
bit := newNode.choose(node.bits)
newNode.child[bit] = node
return newNode
}
// create new parent for node & newNode
parent := &trieEntry{
bits: ip,
peer: nil,
cidr: cidr,
bit_at_byte: cidr / 8,
bit_at_shift: 7 - (cidr % 8),
}
bit := parent.choose(ip)
parent.child[bit] = newNode
parent.child[bit^1] = node
return parent
}
func (node *trieEntry) lookup(ip net.IP) *Peer {
var found *Peer
size := uint(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
if node.peer != nil {
found = node.peer
}
if node.bit_at_byte == size {
break
}
bit := node.choose(ip)
node = node.child[bit]
}
return found
}
func (node *trieEntry) entriesForPeer(p *Peer, results []net.IPNet) []net.IPNet {
if node == nil {
return results
}
if node.peer == p {
mask := net.CIDRMask(int(node.cidr), len(node.bits)*8)
results = append(results, net.IPNet{
Mask: mask,
IP: node.bits.Mask(mask),
})
}
results = node.child[0].entriesForPeer(p, results)
results = node.child[1].entriesForPeer(p, results)
return results
}
type AllowedIPs struct {
IPv4 *trieEntry
IPv6 *trieEntry
mutex sync.RWMutex
}
func (table *AllowedIPs) EntriesForPeer(peer *Peer) []net.IPNet {
table.mutex.RLock()
defer table.mutex.RUnlock()
allowed := make([]net.IPNet, 0, 10)
allowed = table.IPv4.entriesForPeer(peer, allowed)
allowed = table.IPv6.entriesForPeer(peer, allowed)
return allowed
}
func (table *AllowedIPs) Reset() {
table.mutex.Lock()
defer table.mutex.Unlock()
table.IPv4 = nil
table.IPv6 = nil
}
func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
table.IPv4 = table.IPv4.removeByPeer(peer)
table.IPv6 = table.IPv6.removeByPeer(peer)
}
func (table *AllowedIPs) Insert(ip net.IP, cidr uint, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
switch len(ip) {
case net.IPv6len:
table.IPv6 = table.IPv6.insert(ip, cidr, peer)
case net.IPv4len:
table.IPv4 = table.IPv4.insert(ip, cidr, peer)
default:
panic(errors.New("inserting unknown address type"))
}
}
func (table *AllowedIPs) LookupIPv4(address []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
return table.IPv4.lookup(address)
}
func (table *AllowedIPs) LookupIPv6(address []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
return table.IPv6.lookup(address)
}

View File

@@ -0,0 +1,131 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
"sort"
"testing"
)
const (
NumberOfPeers = 100
NumberOfAddresses = 250
NumberOfTests = 10000
)
type SlowNode struct {
peer *Peer
cidr uint
bits []byte
}
type SlowRouter []*SlowNode
func (r SlowRouter) Len() int {
return len(r)
}
func (r SlowRouter) Less(i, j int) bool {
return r[i].cidr > r[j].cidr
}
func (r SlowRouter) Swap(i, j int) {
r[i], r[j] = r[j], r[i]
}
func (r SlowRouter) Insert(addr []byte, cidr uint, peer *Peer) SlowRouter {
for _, t := range r {
if t.cidr == cidr && commonBits(t.bits, addr) >= cidr {
t.peer = peer
t.bits = addr
return r
}
}
r = append(r, &SlowNode{
cidr: cidr,
bits: addr,
peer: peer,
})
sort.Sort(r)
return r
}
func (r SlowRouter) Lookup(addr []byte) *Peer {
for _, t := range r {
common := commonBits(t.bits, addr)
if common >= t.cidr {
return t.peer
}
}
return nil
}
func TestTrieRandomIPv4(t *testing.T) {
var trie *trieEntry
var slow SlowRouter
var peers []*Peer
rand.Seed(1)
const AddressLength = 4
for n := 0; n < NumberOfPeers; n += 1 {
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}
for n := 0; n < NumberOfTests; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
peer2 := trie.lookup(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
}
}
}
func TestTrieRandomIPv6(t *testing.T) {
var trie *trieEntry
var slow SlowRouter
var peers []*Peer
rand.Seed(1)
const AddressLength = 16
for n := 0; n < NumberOfPeers; n += 1 {
peers = append(peers, &Peer{})
}
for n := 0; n < NumberOfAddresses; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % NumberOfPeers
trie = trie.insert(addr[:], cidr, peers[index])
slow = slow.Insert(addr[:], cidr, peers[index])
}
for n := 0; n < NumberOfTests; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
peer1 := slow.Lookup(addr[:])
peer2 := trie.lookup(addr[:])
if peer1 != peer2 {
t.Error("Trie did not match naive implementation, for:", addr)
}
}
}

260
device/allowedips_test.go Normal file
View File

@@ -0,0 +1,260 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
"net"
"testing"
)
/* Todo: More comprehensive
*/
type testPairCommonBits struct {
s1 []byte
s2 []byte
match uint
}
type testPairTrieInsert struct {
key []byte
cidr uint
peer *Peer
}
type testPairTrieLookup struct {
key []byte
peer *Peer
}
func printTrie(t *testing.T, p *trieEntry) {
if p == nil {
return
}
t.Log(p)
printTrie(t, p.child[0])
printTrie(t, p.child[1])
}
func TestCommonBits(t *testing.T) {
tests := []testPairCommonBits{
{s1: []byte{1, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 7},
{s1: []byte{0, 4, 53, 128}, s2: []byte{0, 0, 0, 0}, match: 13},
{s1: []byte{0, 4, 53, 253}, s2: []byte{0, 4, 53, 252}, match: 31},
{s1: []byte{192, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 15},
{s1: []byte{65, 168, 1, 1}, s2: []byte{192, 169, 1, 1}, match: 0},
}
for _, p := range tests {
v := commonBits(p.s1, p.s2)
if v != p.match {
t.Error(
"For slice", p.s1, p.s2,
"expected match", p.match,
",but got", v,
)
}
}
}
func benchmarkTrie(peerNumber int, addressNumber int, addressLength int, b *testing.B) {
var trie *trieEntry
var peers []*Peer
rand.Seed(1)
const AddressLength = 4
for n := 0; n < peerNumber; n += 1 {
peers = append(peers, &Peer{})
}
for n := 0; n < addressNumber; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
cidr := uint(rand.Uint32() % (AddressLength * 8))
index := rand.Int() % peerNumber
trie = trie.insert(addr[:], cidr, peers[index])
}
for n := 0; n < b.N; n += 1 {
var addr [AddressLength]byte
rand.Read(addr[:])
trie.lookup(addr[:])
}
}
func BenchmarkTrieIPv4Peers100Addresses1000(b *testing.B) {
benchmarkTrie(100, 1000, net.IPv4len, b)
}
func BenchmarkTrieIPv4Peers10Addresses10(b *testing.B) {
benchmarkTrie(10, 10, net.IPv4len, b)
}
func BenchmarkTrieIPv6Peers100Addresses1000(b *testing.B) {
benchmarkTrie(100, 1000, net.IPv6len, b)
}
func BenchmarkTrieIPv6Peers10Addresses10(b *testing.B) {
benchmarkTrie(10, 10, net.IPv6len, b)
}
/* Test ported from kernel implementation:
* selftest/allowedips.h
*/
func TestTrieIPv4(t *testing.T) {
a := &Peer{}
b := &Peer{}
c := &Peer{}
d := &Peer{}
e := &Peer{}
g := &Peer{}
h := &Peer{}
var trie *trieEntry
insert := func(peer *Peer, a, b, c, d byte, cidr uint) {
trie = trie.insert([]byte{a, b, c, d}, cidr, peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
p := trie.lookup([]byte{a, b, c, d})
if p != peer {
t.Error("Assert EQ failed")
}
}
assertNEQ := func(peer *Peer, a, b, c, d byte) {
p := trie.lookup([]byte{a, b, c, d})
if p == peer {
t.Error("Assert NEQ failed")
}
}
insert(a, 192, 168, 4, 0, 24)
insert(b, 192, 168, 4, 4, 32)
insert(c, 192, 168, 0, 0, 16)
insert(d, 192, 95, 5, 64, 27)
insert(c, 192, 95, 5, 65, 27)
insert(e, 0, 0, 0, 0, 0)
insert(g, 64, 15, 112, 0, 20)
insert(h, 64, 15, 123, 211, 25)
insert(a, 10, 0, 0, 0, 25)
insert(b, 10, 0, 0, 128, 25)
insert(a, 10, 1, 0, 0, 30)
insert(b, 10, 1, 0, 4, 30)
insert(c, 10, 1, 0, 8, 29)
insert(d, 10, 1, 0, 16, 29)
assertEQ(a, 192, 168, 4, 20)
assertEQ(a, 192, 168, 4, 0)
assertEQ(b, 192, 168, 4, 4)
assertEQ(c, 192, 168, 200, 182)
assertEQ(c, 192, 95, 5, 68)
assertEQ(e, 192, 95, 5, 96)
assertEQ(g, 64, 15, 116, 26)
assertEQ(g, 64, 15, 127, 3)
insert(a, 1, 0, 0, 0, 32)
insert(a, 64, 0, 0, 0, 32)
insert(a, 128, 0, 0, 0, 32)
insert(a, 192, 0, 0, 0, 32)
insert(a, 255, 0, 0, 0, 32)
assertEQ(a, 1, 0, 0, 0)
assertEQ(a, 64, 0, 0, 0)
assertEQ(a, 128, 0, 0, 0)
assertEQ(a, 192, 0, 0, 0)
assertEQ(a, 255, 0, 0, 0)
trie = trie.removeByPeer(a)
assertNEQ(a, 1, 0, 0, 0)
assertNEQ(a, 64, 0, 0, 0)
assertNEQ(a, 128, 0, 0, 0)
assertNEQ(a, 192, 0, 0, 0)
assertNEQ(a, 255, 0, 0, 0)
trie = nil
insert(a, 192, 168, 0, 0, 16)
insert(a, 192, 168, 0, 0, 24)
trie = trie.removeByPeer(a)
assertNEQ(a, 192, 168, 0, 1)
}
/* Test ported from kernel implementation:
* selftest/allowedips.h
*/
func TestTrieIPv6(t *testing.T) {
a := &Peer{}
b := &Peer{}
c := &Peer{}
d := &Peer{}
e := &Peer{}
f := &Peer{}
g := &Peer{}
h := &Peer{}
var trie *trieEntry
expand := func(a uint32) []byte {
var out [4]byte
out[0] = byte(a >> 24 & 0xff)
out[1] = byte(a >> 16 & 0xff)
out[2] = byte(a >> 8 & 0xff)
out[3] = byte(a & 0xff)
return out[:]
}
insert := func(peer *Peer, a, b, c, d uint32, cidr uint) {
var addr []byte
addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
trie = trie.insert(addr, cidr, peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {
var addr []byte
addr = append(addr, expand(a)...)
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
p := trie.lookup(addr)
if p != peer {
t.Error("Assert EQ failed")
}
}
insert(d, 0x26075300, 0x60006b00, 0, 0xc05f0543, 128)
insert(c, 0x26075300, 0x60006b00, 0, 0, 64)
insert(e, 0, 0, 0, 0, 0)
insert(f, 0, 0, 0, 0, 0)
insert(g, 0x24046800, 0, 0, 0, 32)
insert(h, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 64)
insert(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef, 128)
insert(c, 0x24446800, 0x40e40800, 0xdeaebeef, 0xdefbeef, 128)
insert(b, 0x24446800, 0xf0e40800, 0xeeaebeef, 0, 98)
assertEQ(d, 0x26075300, 0x60006b00, 0, 0xc05f0543)
assertEQ(c, 0x26075300, 0x60006b00, 0, 0xc02e01ee)
assertEQ(f, 0x26075300, 0x60006b01, 0, 0)
assertEQ(g, 0x24046800, 0x40040806, 0, 0x1006)
assertEQ(g, 0x24046800, 0x40040806, 0x1234, 0x5678)
assertEQ(f, 0x240467ff, 0x40040806, 0x1234, 0x5678)
assertEQ(f, 0x24046801, 0x40040806, 0x1234, 0x5678)
assertEQ(h, 0x24046800, 0x40040800, 0x1234, 0x5678)
assertEQ(h, 0x24046800, 0x40040800, 0, 0)
assertEQ(h, 0x24046800, 0x40040800, 0x10101010, 0x10101010)
assertEQ(a, 0x24046800, 0x40040800, 0xdeadbeef, 0xdeadbeef)
}

55
device/bind_test.go Normal file
View File

@@ -0,0 +1,55 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import "errors"
type DummyDatagram struct {
msg []byte
endpoint Endpoint
world bool // better type
}
type DummyBind struct {
in6 chan DummyDatagram
ou6 chan DummyDatagram
in4 chan DummyDatagram
ou4 chan DummyDatagram
closed bool
}
func (b *DummyBind) SetMark(v uint32) error {
return nil
}
func (b *DummyBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
datagram, ok := <-b.in6
if !ok {
return 0, nil, errors.New("closed")
}
copy(buff, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil
}
func (b *DummyBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
datagram, ok := <-b.in4
if !ok {
return 0, nil, errors.New("closed")
}
copy(buff, datagram.msg)
return len(datagram.msg), datagram.endpoint, nil
}
func (b *DummyBind) Close() error {
close(b.in6)
close(b.in4)
b.closed = true
return nil
}
func (b *DummyBind) Send(buff []byte, end Endpoint) error {
return nil
}

180
device/conn.go Normal file
View File

@@ -0,0 +1,180 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net"
)
const (
ConnRoutineNumber = 2
)
/* A Bind handles listening on a port for both IPv6 and IPv4 UDP traffic
*/
type Bind interface {
SetMark(value uint32) error
ReceiveIPv6(buff []byte) (int, Endpoint, error)
ReceiveIPv4(buff []byte) (int, Endpoint, error)
Send(buff []byte, end Endpoint) error
Close() error
}
/* An Endpoint maintains the source/destination caching for a peer
*
* dst : the remote address of a peer ("endpoint" in uapi terminology)
* src : the local address from which datagrams originate going to the peer
*/
type Endpoint interface {
ClearSrc() // clears the source address
SrcToString() string // returns the local source address (ip:port)
DstToString() string // returns the destination address (ip:port)
DstToBytes() []byte // used for mac2 cookie calculations
DstIP() net.IP
SrcIP() net.IP
}
func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address
host, _, err := net.SplitHostPort(s)
if err != nil {
return nil, err
}
if ip := net.ParseIP(host); ip == nil {
return nil, errors.New("Failed to parse IP address: " + host)
}
// parse address and port
addr, err := net.ResolveUDPAddr("udp", s)
if err != nil {
return nil, err
}
ip4 := addr.IP.To4()
if ip4 != nil {
addr.IP = ip4
}
return addr, err
}
func unsafeCloseBind(device *Device) error {
var err error
netc := &device.net
if netc.bind != nil {
err = netc.bind.Close()
netc.bind = nil
}
netc.stopping.Wait()
return err
}
func (device *Device) BindSetMark(mark uint32) error {
device.net.Lock()
defer device.net.Unlock()
// check if modified
if device.net.fwmark == mark {
return nil
}
// update fwmark on existing bind
device.net.fwmark = mark
if device.isUp.Get() && device.net.bind != nil {
if err := device.net.bind.SetMark(mark); err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
return nil
}
func (device *Device) BindUpdate() error {
device.net.Lock()
defer device.net.Unlock()
// close existing sockets
if err := unsafeCloseBind(device); err != nil {
return err
}
// open new sockets
if device.isUp.Get() {
// bind to new port
var err error
netc := &device.net
netc.bind, netc.port, err = CreateBind(netc.port, device)
if err != nil {
netc.bind = nil
netc.port = 0
return err
}
// set fwmark
if netc.fwmark != 0 {
err = netc.bind.SetMark(netc.fwmark)
if err != nil {
return err
}
}
// clear cached source addresses
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Lock()
defer peer.Unlock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
}
device.peers.RUnlock()
// start receiving routines
device.net.starting.Add(ConnRoutineNumber)
device.net.stopping.Add(ConnRoutineNumber)
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
device.net.starting.Wait()
device.log.Debug.Println("UDP bind has been updated")
}
return nil
}
func (device *Device) BindClose() error {
device.net.Lock()
err := unsafeCloseBind(device)
device.net.Unlock()
return err
}

170
device/conn_default.go Normal file
View File

@@ -0,0 +1,170 @@
// +build !linux android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"net"
"os"
"syscall"
)
/* This code is meant to be a temporary solution
* on platforms for which the sticky socket / source caching behavior
* has not yet been implemented.
*
* See conn_linux.go for an implementation on the linux platform.
*/
type NativeBind struct {
ipv4 *net.UDPConn
ipv6 *net.UDPConn
}
type NativeEndpoint net.UDPAddr
var _ Bind = (*NativeBind)(nil)
var _ Endpoint = (*NativeEndpoint)(nil)
func CreateEndpoint(s string) (Endpoint, error) {
addr, err := parseEndpoint(s)
return (*NativeEndpoint)(addr), err
}
func (_ *NativeEndpoint) ClearSrc() {}
func (e *NativeEndpoint) DstIP() net.IP {
return (*net.UDPAddr)(e).IP
}
func (e *NativeEndpoint) SrcIP() net.IP {
return nil // not supported
}
func (e *NativeEndpoint) DstToBytes() []byte {
addr := (*net.UDPAddr)(e)
out := addr.IP.To4()
if out == nil {
out = addr.IP
}
out = append(out, byte(addr.Port&0xff))
out = append(out, byte((addr.Port>>8)&0xff))
return out
}
func (e *NativeEndpoint) DstToString() string {
return (*net.UDPAddr)(e).String()
}
func (e *NativeEndpoint) SrcToString() string {
return ""
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
// listen
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
if err != nil {
return nil, 0, err
}
// retrieve port
laddr := conn.LocalAddr()
uaddr, err := net.ResolveUDPAddr(
laddr.Network(),
laddr.String(),
)
if err != nil {
return nil, 0, err
}
return conn, uaddr.Port, nil
}
func extractErrno(err error) error {
opErr, ok := err.(*net.OpError)
if !ok {
return nil
}
syscallErr, ok := opErr.Err.(*os.SyscallError)
if !ok {
return nil
}
return syscallErr.Err
}
func CreateBind(uport uint16, device *Device) (Bind, uint16, error) {
var err error
var bind NativeBind
port := int(uport)
bind.ipv4, port, err = listenNet("udp4", port)
if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
return nil, 0, err
}
bind.ipv6, port, err = listenNet("udp6", port)
if err != nil && extractErrno(err) != syscall.EAFNOSUPPORT {
bind.ipv4.Close()
bind.ipv4 = nil
return nil, 0, err
}
return &bind, uint16(port), nil
}
func (bind *NativeBind) Close() error {
var err1, err2 error
if bind.ipv4 != nil {
err1 = bind.ipv4.Close()
}
if bind.ipv6 != nil {
err2 = bind.ipv6.Close()
}
if err1 != nil {
return err1
}
return err2
}
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
if bind.ipv4 == nil {
return 0, nil, syscall.EAFNOSUPPORT
}
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
if endpoint != nil {
endpoint.IP = endpoint.IP.To4()
}
return n, (*NativeEndpoint)(endpoint), err
}
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
if bind.ipv6 == nil {
return 0, nil, syscall.EAFNOSUPPORT
}
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
return n, (*NativeEndpoint)(endpoint), err
}
func (bind *NativeBind) Send(buff []byte, endpoint Endpoint) error {
var err error
nend := endpoint.(*NativeEndpoint)
if nend.IP.To4() != nil {
if bind.ipv4 == nil {
return syscall.EAFNOSUPPORT
}
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend))
} else {
if bind.ipv6 == nil {
return syscall.EAFNOSUPPORT
}
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
}
return err
}

746
device/conn_linux.go Normal file
View File

@@ -0,0 +1,746 @@
// +build !android
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*
* This implements userspace semantics of "sticky sockets", modeled after
* WireGuard's kernelspace implementation. This is more or less a straight port
* of the sticky-sockets.c example code:
* https://git.zx2c4.com/WireGuard/tree/contrib/examples/sticky-sockets/sticky-sockets.c
*
* Currently there is no way to achieve this within the net package:
* See e.g. https://github.com/golang/go/issues/17930
* So this code is remains platform dependent.
*/
package device
import (
"errors"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/rwcancel"
"net"
"strconv"
"sync"
"syscall"
"unsafe"
)
const (
FD_ERR = -1
)
type IPv4Source struct {
src [4]byte
ifindex int32
}
type IPv6Source struct {
src [16]byte
//ifindex belongs in dst.ZoneId
}
type NativeEndpoint struct {
dst [unsafe.Sizeof(unix.SockaddrInet6{})]byte
src [unsafe.Sizeof(IPv6Source{})]byte
isV6 bool
}
func (endpoint *NativeEndpoint) src4() *IPv4Source {
return (*IPv4Source)(unsafe.Pointer(&endpoint.src[0]))
}
func (endpoint *NativeEndpoint) src6() *IPv6Source {
return (*IPv6Source)(unsafe.Pointer(&endpoint.src[0]))
}
func (endpoint *NativeEndpoint) dst4() *unix.SockaddrInet4 {
return (*unix.SockaddrInet4)(unsafe.Pointer(&endpoint.dst[0]))
}
func (endpoint *NativeEndpoint) dst6() *unix.SockaddrInet6 {
return (*unix.SockaddrInet6)(unsafe.Pointer(&endpoint.dst[0]))
}
type NativeBind struct {
sock4 int
sock6 int
netlinkSock int
netlinkCancel *rwcancel.RWCancel
lastMark uint32
}
var _ Endpoint = (*NativeEndpoint)(nil)
var _ Bind = (*NativeBind)(nil)
func CreateEndpoint(s string) (Endpoint, error) {
var end NativeEndpoint
addr, err := parseEndpoint(s)
if err != nil {
return nil, err
}
ipv4 := addr.IP.To4()
if ipv4 != nil {
dst := end.dst4()
end.isV6 = false
dst.Port = addr.Port
copy(dst.Addr[:], ipv4)
end.ClearSrc()
return &end, nil
}
ipv6 := addr.IP.To16()
if ipv6 != nil {
zone, err := zoneToUint32(addr.Zone)
if err != nil {
return nil, err
}
dst := end.dst6()
end.isV6 = true
dst.Port = addr.Port
dst.ZoneId = zone
copy(dst.Addr[:], ipv6[:])
end.ClearSrc()
return &end, nil
}
return nil, errors.New("Invalid IP address")
}
func createNetlinkRouteSocket() (int, error) {
sock, err := unix.Socket(unix.AF_NETLINK, unix.SOCK_RAW, unix.NETLINK_ROUTE)
if err != nil {
return -1, err
}
saddr := &unix.SockaddrNetlink{
Family: unix.AF_NETLINK,
Groups: uint32(1 << (unix.RTNLGRP_IPV4_ROUTE - 1)),
}
err = unix.Bind(sock, saddr)
if err != nil {
unix.Close(sock)
return -1, err
}
return sock, nil
}
func CreateBind(port uint16, device *Device) (*NativeBind, uint16, error) {
var err error
var bind NativeBind
var newPort uint16
bind.netlinkSock, err = createNetlinkRouteSocket()
if err != nil {
return nil, 0, err
}
bind.netlinkCancel, err = rwcancel.NewRWCancel(bind.netlinkSock)
if err != nil {
unix.Close(bind.netlinkSock)
return nil, 0, err
}
go bind.routineRouteListener(device)
// attempt ipv6 bind, update port if succesful
bind.sock6, newPort, err = create6(port)
if err != nil {
if err != syscall.EAFNOSUPPORT {
bind.netlinkCancel.Cancel()
return nil, 0, err
}
} else {
port = newPort
}
// attempt ipv4 bind, update port if succesful
bind.sock4, newPort, err = create4(port)
if err != nil {
if err != syscall.EAFNOSUPPORT {
bind.netlinkCancel.Cancel()
unix.Close(bind.sock6)
return nil, 0, err
}
} else {
port = newPort
}
if bind.sock4 == FD_ERR && bind.sock6 == FD_ERR {
return nil, 0, errors.New("ipv4 and ipv6 not supported")
}
return &bind, port, nil
}
func (bind *NativeBind) SetMark(value uint32) error {
if bind.sock6 != -1 {
err := unix.SetsockoptInt(
bind.sock6,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
if err != nil {
return err
}
}
if bind.sock4 != -1 {
err := unix.SetsockoptInt(
bind.sock4,
unix.SOL_SOCKET,
unix.SO_MARK,
int(value),
)
if err != nil {
return err
}
}
bind.lastMark = value
return nil
}
func closeUnblock(fd int) error {
// shutdown to unblock readers and writers
unix.Shutdown(fd, unix.SHUT_RDWR)
return unix.Close(fd)
}
func (bind *NativeBind) Close() error {
var err1, err2, err3 error
if bind.sock6 != -1 {
err1 = closeUnblock(bind.sock6)
}
if bind.sock4 != -1 {
err2 = closeUnblock(bind.sock4)
}
err3 = bind.netlinkCancel.Cancel()
if err1 != nil {
return err1
}
if err2 != nil {
return err2
}
return err3
}
func (bind *NativeBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint
if bind.sock6 == -1 {
return 0, nil, syscall.EAFNOSUPPORT
}
n, err := receive6(
bind.sock6,
buff,
&end,
)
return n, &end, err
}
func (bind *NativeBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
var end NativeEndpoint
if bind.sock4 == -1 {
return 0, nil, syscall.EAFNOSUPPORT
}
n, err := receive4(
bind.sock4,
buff,
&end,
)
return n, &end, err
}
func (bind *NativeBind) Send(buff []byte, end Endpoint) error {
nend := end.(*NativeEndpoint)
if !nend.isV6 {
if bind.sock4 == -1 {
return syscall.EAFNOSUPPORT
}
return send4(bind.sock4, nend, buff)
} else {
if bind.sock6 == -1 {
return syscall.EAFNOSUPPORT
}
return send6(bind.sock6, nend, buff)
}
}
func (end *NativeEndpoint) SrcIP() net.IP {
if !end.isV6 {
return net.IPv4(
end.src4().src[0],
end.src4().src[1],
end.src4().src[2],
end.src4().src[3],
)
} else {
return end.src6().src[:]
}
}
func (end *NativeEndpoint) DstIP() net.IP {
if !end.isV6 {
return net.IPv4(
end.dst4().Addr[0],
end.dst4().Addr[1],
end.dst4().Addr[2],
end.dst4().Addr[3],
)
} else {
return end.dst6().Addr[:]
}
}
func (end *NativeEndpoint) DstToBytes() []byte {
if !end.isV6 {
return (*[unsafe.Offsetof(end.dst4().Addr) + unsafe.Sizeof(end.dst4().Addr)]byte)(unsafe.Pointer(end.dst4()))[:]
} else {
return (*[unsafe.Offsetof(end.dst6().Addr) + unsafe.Sizeof(end.dst6().Addr)]byte)(unsafe.Pointer(end.dst6()))[:]
}
}
func (end *NativeEndpoint) SrcToString() string {
return end.SrcIP().String()
}
func (end *NativeEndpoint) DstToString() string {
var udpAddr net.UDPAddr
udpAddr.IP = end.DstIP()
if !end.isV6 {
udpAddr.Port = end.dst4().Port
} else {
udpAddr.Port = end.dst6().Port
}
return udpAddr.String()
}
func (end *NativeEndpoint) ClearDst() {
for i := range end.dst {
end.dst[i] = 0
}
}
func (end *NativeEndpoint) ClearSrc() {
for i := range end.src {
end.src[i] = 0
}
}
func zoneToUint32(zone string) (uint32, error) {
if zone == "" {
return 0, nil
}
if intr, err := net.InterfaceByName(zone); err == nil {
return uint32(intr.Index), nil
}
n, err := strconv.ParseUint(zone, 10, 32)
return uint32(n), err
}
func create4(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return FD_ERR, 0, err
}
addr := unix.SockaddrInet4{
Port: int(port),
}
// set sockopts and bind
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IP,
unix.IP_PKTINFO,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
return FD_ERR, 0, err
}
return fd, uint16(addr.Port), err
}
func create6(port uint16) (int, uint16, error) {
// create socket
fd, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return FD_ERR, 0, err
}
// set sockopts and bind
addr := unix.SockaddrInet6{
Port: int(port),
}
if err := func() error {
if err := unix.SetsockoptInt(
fd,
unix.SOL_SOCKET,
unix.SO_REUSEADDR,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_RECVPKTINFO,
1,
); err != nil {
return err
}
if err := unix.SetsockoptInt(
fd,
unix.IPPROTO_IPV6,
unix.IPV6_V6ONLY,
1,
); err != nil {
return err
}
return unix.Bind(fd, &addr)
}(); err != nil {
unix.Close(fd)
return FD_ERR, 0, err
}
return fd, uint16(addr.Port), err
}
func send4(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
}{
unix.Cmsghdr{
Level: unix.IPPROTO_IP,
Type: unix.IP_PKTINFO,
Len: unix.SizeofInet4Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet4Pktinfo{
Spec_dst: end.src4().src,
Ifindex: end.src4().ifindex,
},
}
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
if err == nil {
return nil
}
// clear src and retry
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet4Pktinfo{}
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst4(), 0)
}
return err
}
func send6(sock int, end *NativeEndpoint, buff []byte) error {
// construct message header
cmsg := struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}{
unix.Cmsghdr{
Level: unix.IPPROTO_IPV6,
Type: unix.IPV6_PKTINFO,
Len: unix.SizeofInet6Pktinfo + unix.SizeofCmsghdr,
},
unix.Inet6Pktinfo{
Addr: end.src6().src,
Ifindex: end.dst6().ZoneId,
},
}
if cmsg.pktinfo.Addr == [16]byte{} {
cmsg.pktinfo.Ifindex = 0
}
_, err := unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
if err == nil {
return nil
}
// clear src and retry
if err == unix.EINVAL {
end.ClearSrc()
cmsg.pktinfo = unix.Inet6Pktinfo{}
_, err = unix.SendmsgN(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], end.dst6(), 0)
}
return err
}
func receive4(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet4Pktinfo
}
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil {
return 0, err
}
end.isV6 = false
if newDst4, ok := newDst.(*unix.SockaddrInet4); ok {
*end.dst4() = *newDst4
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IP &&
cmsg.cmsghdr.Type == unix.IP_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet4Pktinfo {
end.src4().src = cmsg.pktinfo.Spec_dst
end.src4().ifindex = cmsg.pktinfo.Ifindex
}
return size, nil
}
func receive6(sock int, buff []byte, end *NativeEndpoint) (int, error) {
// contruct message header
var cmsg struct {
cmsghdr unix.Cmsghdr
pktinfo unix.Inet6Pktinfo
}
size, _, _, newDst, err := unix.Recvmsg(sock, buff, (*[unsafe.Sizeof(cmsg)]byte)(unsafe.Pointer(&cmsg))[:], 0)
if err != nil {
return 0, err
}
end.isV6 = true
if newDst6, ok := newDst.(*unix.SockaddrInet6); ok {
*end.dst6() = *newDst6
}
// update source cache
if cmsg.cmsghdr.Level == unix.IPPROTO_IPV6 &&
cmsg.cmsghdr.Type == unix.IPV6_PKTINFO &&
cmsg.cmsghdr.Len >= unix.SizeofInet6Pktinfo {
end.src6().src = cmsg.pktinfo.Addr
end.dst6().ZoneId = cmsg.pktinfo.Ifindex
}
return size, nil
}
func (bind *NativeBind) routineRouteListener(device *Device) {
type peerEndpointPtr struct {
peer *Peer
endpoint *Endpoint
}
var reqPeer map[uint32]peerEndpointPtr
var reqPeerLock sync.Mutex
defer unix.Close(bind.netlinkSock)
for msg := make([]byte, 1<<16); ; {
var err error
var msgn int
for {
msgn, _, _, _, err = unix.Recvmsg(bind.netlinkSock, msg[:], nil, 0)
if err == nil || !rwcancel.RetryAfterError(err) {
break
}
if !bind.netlinkCancel.ReadyRead() {
return
}
}
if err != nil {
return
}
for remain := msg[:msgn]; len(remain) >= unix.SizeofNlMsghdr; {
hdr := *(*unix.NlMsghdr)(unsafe.Pointer(&remain[0]))
if uint(hdr.Len) > uint(len(remain)) {
break
}
switch hdr.Type {
case unix.RTM_NEWROUTE, unix.RTM_DELROUTE:
if hdr.Seq <= MaxPeers && hdr.Seq > 0 {
if uint(len(remain)) < uint(hdr.Len) {
break
}
if hdr.Len > unix.SizeofNlMsghdr+unix.SizeofRtMsg {
attr := remain[unix.SizeofNlMsghdr+unix.SizeofRtMsg:]
for {
if uint(len(attr)) < uint(unix.SizeofRtAttr) {
break
}
attrhdr := *(*unix.RtAttr)(unsafe.Pointer(&attr[0]))
if attrhdr.Len < unix.SizeofRtAttr || uint(len(attr)) < uint(attrhdr.Len) {
break
}
if attrhdr.Type == unix.RTA_OIF && attrhdr.Len == unix.SizeofRtAttr+4 {
ifidx := *(*uint32)(unsafe.Pointer(&attr[unix.SizeofRtAttr]))
reqPeerLock.Lock()
if reqPeer == nil {
reqPeerLock.Unlock()
break
}
pePtr, ok := reqPeer[hdr.Seq]
reqPeerLock.Unlock()
if !ok {
break
}
pePtr.peer.Lock()
if &pePtr.peer.endpoint != pePtr.endpoint {
pePtr.peer.Unlock()
break
}
if uint32(pePtr.peer.endpoint.(*NativeEndpoint).src4().ifindex) == ifidx {
pePtr.peer.Unlock()
break
}
pePtr.peer.endpoint.(*NativeEndpoint).ClearSrc()
pePtr.peer.Unlock()
}
attr = attr[attrhdr.Len:]
}
}
break
}
reqPeerLock.Lock()
reqPeer = make(map[uint32]peerEndpointPtr)
reqPeerLock.Unlock()
go func() {
device.peers.RLock()
i := uint32(1)
for _, peer := range device.peers.keyMap {
peer.RLock()
if peer.endpoint == nil || peer.endpoint.(*NativeEndpoint) == nil {
peer.RUnlock()
continue
}
if peer.endpoint.(*NativeEndpoint).isV6 || peer.endpoint.(*NativeEndpoint).src4().ifindex == 0 {
peer.RUnlock()
break
}
nlmsg := struct {
hdr unix.NlMsghdr
msg unix.RtMsg
dsthdr unix.RtAttr
dst [4]byte
srchdr unix.RtAttr
src [4]byte
markhdr unix.RtAttr
mark uint32
}{
unix.NlMsghdr{
Type: uint16(unix.RTM_GETROUTE),
Flags: unix.NLM_F_REQUEST,
Seq: i,
},
unix.RtMsg{
Family: unix.AF_INET,
Dst_len: 32,
Src_len: 32,
},
unix.RtAttr{
Len: 8,
Type: unix.RTA_DST,
},
peer.endpoint.(*NativeEndpoint).dst4().Addr,
unix.RtAttr{
Len: 8,
Type: unix.RTA_SRC,
},
peer.endpoint.(*NativeEndpoint).src4().src,
unix.RtAttr{
Len: 8,
Type: 0x10, //unix.RTA_MARK TODO: add this to x/sys/unix
},
uint32(bind.lastMark),
}
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
reqPeerLock.Lock()
reqPeer[i] = peerEndpointPtr{
peer: peer,
endpoint: &peer.endpoint,
}
reqPeerLock.Unlock()
peer.RUnlock()
i++
_, err := bind.netlinkCancel.Write((*[unsafe.Sizeof(nlmsg)]byte)(unsafe.Pointer(&nlmsg))[:])
if err != nil {
break
}
}
device.peers.RUnlock()
}()
}
remain = remain[hdr.Len:]
}
}
}

41
device/constants.go Normal file
View File

@@ -0,0 +1,41 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"time"
)
/* Specification constants */
const (
RekeyAfterMessages = (1 << 64) - (1 << 16) - 1
RejectAfterMessages = (1 << 64) - (1 << 4) - 1
RekeyAfterTime = time.Second * 120
RekeyAttemptTime = time.Second * 90
RekeyTimeout = time.Second * 5
MaxTimerHandshakes = 90 / 5 /* RekeyAttemptTime / RekeyTimeout */
RekeyTimeoutJitterMaxMs = 334
RejectAfterTime = time.Second * 180
KeepaliveTimeout = time.Second * 10
CookieRefreshTime = time.Second * 120
HandshakeInitationRate = time.Second / 20
PaddingMultiple = 16
)
const (
MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive)
MaxMessageSize = MaxSegmentSize // maximum size of transport message
MaxContentSize = MaxSegmentSize - MessageTransportSize // maximum size of transport message content
)
/* Implementation constants */
const (
UnderLoadQueueSize = QueueHandshakeSize / 8
UnderLoadAfterTime = time.Second // how long does the device remain under load after detected
MaxPeers = 1 << 16 // maximum number of configured peers
)

250
device/cookie.go Normal file
View File

@@ -0,0 +1,250 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"crypto/hmac"
"crypto/rand"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"sync"
"time"
)
type CookieChecker struct {
sync.RWMutex
mac1 struct {
key [blake2s.Size]byte
}
mac2 struct {
secret [blake2s.Size]byte
secretSet time.Time
encryptionKey [chacha20poly1305.KeySize]byte
}
}
type CookieGenerator struct {
sync.RWMutex
mac1 struct {
key [blake2s.Size]byte
}
mac2 struct {
cookie [blake2s.Size128]byte
cookieSet time.Time
hasLastMAC1 bool
lastMAC1 [blake2s.Size128]byte
encryptionKey [chacha20poly1305.KeySize]byte
}
}
func (st *CookieChecker) Init(pk NoisePublicKey) {
st.Lock()
defer st.Unlock()
// mac1 state
func() {
hash, _ := blake2s.New256(nil)
hash.Write([]byte(WGLabelMAC1))
hash.Write(pk[:])
hash.Sum(st.mac1.key[:0])
}()
// mac2 state
func() {
hash, _ := blake2s.New256(nil)
hash.Write([]byte(WGLabelCookie))
hash.Write(pk[:])
hash.Sum(st.mac2.encryptionKey[:0])
}()
st.mac2.secretSet = time.Time{}
}
func (st *CookieChecker) CheckMAC1(msg []byte) bool {
st.RLock()
defer st.RUnlock()
size := len(msg)
smac2 := size - blake2s.Size128
smac1 := smac2 - blake2s.Size128
var mac1 [blake2s.Size128]byte
mac, _ := blake2s.New128(st.mac1.key[:])
mac.Write(msg[:smac1])
mac.Sum(mac1[:0])
return hmac.Equal(mac1[:], msg[smac1:smac2])
}
func (st *CookieChecker) CheckMAC2(msg []byte, src []byte) bool {
st.RLock()
defer st.RUnlock()
if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime {
return false
}
// derive cookie key
var cookie [blake2s.Size128]byte
func() {
mac, _ := blake2s.New128(st.mac2.secret[:])
mac.Write(src)
mac.Sum(cookie[:0])
}()
// calculate mac of packet (including mac1)
smac2 := len(msg) - blake2s.Size128
var mac2 [blake2s.Size128]byte
func() {
mac, _ := blake2s.New128(cookie[:])
mac.Write(msg[:smac2])
mac.Sum(mac2[:0])
}()
return hmac.Equal(mac2[:], msg[smac2:])
}
func (st *CookieChecker) CreateReply(
msg []byte,
recv uint32,
src []byte,
) (*MessageCookieReply, error) {
st.RLock()
// refresh cookie secret
if time.Now().Sub(st.mac2.secretSet) > CookieRefreshTime {
st.RUnlock()
st.Lock()
_, err := rand.Read(st.mac2.secret[:])
if err != nil {
st.Unlock()
return nil, err
}
st.mac2.secretSet = time.Now()
st.Unlock()
st.RLock()
}
// derive cookie
var cookie [blake2s.Size128]byte
func() {
mac, _ := blake2s.New128(st.mac2.secret[:])
mac.Write(src)
mac.Sum(cookie[:0])
}()
// encrypt cookie
size := len(msg)
smac2 := size - blake2s.Size128
smac1 := smac2 - blake2s.Size128
reply := new(MessageCookieReply)
reply.Type = MessageCookieReplyType
reply.Receiver = recv
_, err := rand.Read(reply.Nonce[:])
if err != nil {
st.RUnlock()
return nil, err
}
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
xchapoly.Seal(reply.Cookie[:0], reply.Nonce[:], cookie[:], msg[smac1:smac2])
st.RUnlock()
return reply, nil
}
func (st *CookieGenerator) Init(pk NoisePublicKey) {
st.Lock()
defer st.Unlock()
func() {
hash, _ := blake2s.New256(nil)
hash.Write([]byte(WGLabelMAC1))
hash.Write(pk[:])
hash.Sum(st.mac1.key[:0])
}()
func() {
hash, _ := blake2s.New256(nil)
hash.Write([]byte(WGLabelCookie))
hash.Write(pk[:])
hash.Sum(st.mac2.encryptionKey[:0])
}()
st.mac2.cookieSet = time.Time{}
}
func (st *CookieGenerator) ConsumeReply(msg *MessageCookieReply) bool {
st.Lock()
defer st.Unlock()
if !st.mac2.hasLastMAC1 {
return false
}
var cookie [blake2s.Size128]byte
xchapoly, _ := chacha20poly1305.NewX(st.mac2.encryptionKey[:])
_, err := xchapoly.Open(cookie[:0], msg.Nonce[:], msg.Cookie[:], st.mac2.lastMAC1[:])
if err != nil {
return false
}
st.mac2.cookieSet = time.Now()
st.mac2.cookie = cookie
return true
}
func (st *CookieGenerator) AddMacs(msg []byte) {
size := len(msg)
smac2 := size - blake2s.Size128
smac1 := smac2 - blake2s.Size128
mac1 := msg[smac1:smac2]
mac2 := msg[smac2:]
st.Lock()
defer st.Unlock()
// set mac1
func() {
mac, _ := blake2s.New128(st.mac1.key[:])
mac.Write(msg[:smac1])
mac.Sum(mac1[:0])
}()
copy(st.mac2.lastMAC1[:], mac1)
st.mac2.hasLastMAC1 = true
// set mac2
if time.Now().Sub(st.mac2.cookieSet) > CookieRefreshTime {
return
}
func() {
mac, _ := blake2s.New128(st.mac2.cookie[:])
mac.Write(msg[:smac2])
mac.Sum(mac2[:0])
}()
}

191
device/cookie_test.go Normal file
View File

@@ -0,0 +1,191 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"testing"
)
func TestCookieMAC1(t *testing.T) {
// setup generator / checker
var (
generator CookieGenerator
checker CookieChecker
)
sk, err := newPrivateKey()
if err != nil {
t.Fatal(err)
}
pk := sk.publicKey()
generator.Init(pk)
checker.Init(pk)
// check mac1
src := []byte{192, 168, 13, 37, 10, 10, 10}
checkMAC1 := func(msg []byte) {
generator.AddMacs(msg)
if !checker.CheckMAC1(msg) {
t.Fatal("MAC1 generation/verification failed")
}
if checker.CheckMAC2(msg, src) {
t.Fatal("MAC2 generation/verification failed")
}
}
checkMAC1([]byte{
0x99, 0xbb, 0xa5, 0xfc, 0x99, 0xaa, 0x83, 0xbd,
0x7b, 0x00, 0xc5, 0x9a, 0x4c, 0xb9, 0xcf, 0x62,
0x40, 0x23, 0xf3, 0x8e, 0xd8, 0xd0, 0x62, 0x64,
0x5d, 0xb2, 0x80, 0x13, 0xda, 0xce, 0xc6, 0x91,
0x61, 0xd6, 0x30, 0xf1, 0x32, 0xb3, 0xa2, 0xf4,
0x7b, 0x43, 0xb5, 0xa7, 0xe2, 0xb1, 0xf5, 0x6c,
0x74, 0x6b, 0xb0, 0xcd, 0x1f, 0x94, 0x86, 0x7b,
0xc8, 0xfb, 0x92, 0xed, 0x54, 0x9b, 0x44, 0xf5,
0xc8, 0x7d, 0xb7, 0x8e, 0xff, 0x49, 0xc4, 0xe8,
0x39, 0x7c, 0x19, 0xe0, 0x60, 0x19, 0x51, 0xf8,
0xe4, 0x8e, 0x02, 0xf1, 0x7f, 0x1d, 0xcc, 0x8e,
0xb0, 0x07, 0xff, 0xf8, 0xaf, 0x7f, 0x66, 0x82,
0x83, 0xcc, 0x7c, 0xfa, 0x80, 0xdb, 0x81, 0x53,
0xad, 0xf7, 0xd8, 0x0c, 0x10, 0xe0, 0x20, 0xfd,
0xe8, 0x0b, 0x3f, 0x90, 0x15, 0xcd, 0x93, 0xad,
0x0b, 0xd5, 0x0c, 0xcc, 0x88, 0x56, 0xe4, 0x3f,
})
checkMAC1([]byte{
0x33, 0xe7, 0x2a, 0x84, 0x9f, 0xff, 0x57, 0x6c,
0x2d, 0xc3, 0x2d, 0xe1, 0xf5, 0x5c, 0x97, 0x56,
0xb8, 0x93, 0xc2, 0x7d, 0xd4, 0x41, 0xdd, 0x7a,
0x4a, 0x59, 0x3b, 0x50, 0xdd, 0x7a, 0x7a, 0x8c,
0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b,
0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc,
0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b,
0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05,
})
checkMAC1([]byte{
0x9b, 0x96, 0xaf, 0x55, 0x3c, 0xeb, 0x6d, 0x0b,
0x13, 0x0b, 0x97, 0x98, 0xb3, 0x40, 0xc3, 0xcc,
0xb8, 0x57, 0x33, 0x45, 0x6e, 0x8b, 0x09, 0x2b,
0x81, 0x2e, 0xd2, 0xb9, 0x66, 0x0b, 0x93, 0x05,
})
// exchange cookie reply
func() {
msg := []byte{
0x6d, 0xd7, 0xc3, 0x2e, 0xb0, 0x76, 0xd8, 0xdf,
0x30, 0x65, 0x7d, 0x62, 0x3e, 0xf8, 0x9a, 0xe8,
0xe7, 0x3c, 0x64, 0xa3, 0x78, 0x48, 0xda, 0xf5,
0x25, 0x61, 0x28, 0x53, 0x79, 0x32, 0x86, 0x9f,
0xa0, 0x27, 0x95, 0x69, 0xb6, 0xba, 0xd0, 0xa2,
0xf8, 0x68, 0xea, 0xa8, 0x62, 0xf2, 0xfd, 0x1b,
0xe0, 0xb4, 0x80, 0xe5, 0x6b, 0x3a, 0x16, 0x9e,
0x35, 0xf6, 0xa8, 0xf2, 0x4f, 0x9a, 0x7b, 0xe9,
0x77, 0x0b, 0xc2, 0xb4, 0xed, 0xba, 0xf9, 0x22,
0xc3, 0x03, 0x97, 0x42, 0x9f, 0x79, 0x74, 0x27,
0xfe, 0xf9, 0x06, 0x6e, 0x97, 0x3a, 0xa6, 0x8f,
0xc9, 0x57, 0x0a, 0x54, 0x4c, 0x64, 0x4a, 0xe2,
0x4f, 0xa1, 0xce, 0x95, 0x9b, 0x23, 0xa9, 0x2b,
0x85, 0x93, 0x42, 0xb0, 0xa5, 0x53, 0xed, 0xeb,
0x63, 0x2a, 0xf1, 0x6d, 0x46, 0xcb, 0x2f, 0x61,
0x8c, 0xe1, 0xe8, 0xfa, 0x67, 0x20, 0x80, 0x6d,
}
generator.AddMacs(msg)
reply, err := checker.CreateReply(msg, 1377, src)
if err != nil {
t.Fatal("Failed to create cookie reply:", err)
}
if !generator.ConsumeReply(reply) {
t.Fatal("Failed to consume cookie reply")
}
}()
// check mac2
checkMAC2 := func(msg []byte) {
generator.AddMacs(msg)
if !checker.CheckMAC1(msg) {
t.Fatal("MAC1 generation/verification failed")
}
if !checker.CheckMAC2(msg, src) {
t.Fatal("MAC2 generation/verification failed")
}
msg[5] ^= 0x20
if checker.CheckMAC1(msg) {
t.Fatal("MAC1 generation/verification failed")
}
if checker.CheckMAC2(msg, src) {
t.Fatal("MAC2 generation/verification failed")
}
msg[5] ^= 0x20
srcBad1 := []byte{192, 168, 13, 37, 40, 01}
if checker.CheckMAC2(msg, srcBad1) {
t.Fatal("MAC2 generation/verification failed")
}
srcBad2 := []byte{192, 168, 13, 38, 40, 01}
if checker.CheckMAC2(msg, srcBad2) {
t.Fatal("MAC2 generation/verification failed")
}
}
checkMAC2([]byte{
0x03, 0x31, 0xb9, 0x9e, 0xb0, 0x2a, 0x54, 0xa3,
0xc1, 0x3f, 0xb4, 0x96, 0x16, 0xb9, 0x25, 0x15,
0x3d, 0x3a, 0x82, 0xf9, 0x58, 0x36, 0x86, 0x3f,
0x13, 0x2f, 0xfe, 0xb2, 0x53, 0x20, 0x8c, 0x3f,
0xba, 0xeb, 0xfb, 0x4b, 0x1b, 0x22, 0x02, 0x69,
0x2c, 0x90, 0xbc, 0xdc, 0xcf, 0xcf, 0x85, 0xeb,
0x62, 0x66, 0x6f, 0xe8, 0xe1, 0xa6, 0xa8, 0x4c,
0xa0, 0x04, 0x23, 0x15, 0x42, 0xac, 0xfa, 0x38,
})
checkMAC2([]byte{
0x0e, 0x2f, 0x0e, 0xa9, 0x29, 0x03, 0xe1, 0xf3,
0x24, 0x01, 0x75, 0xad, 0x16, 0xa5, 0x66, 0x85,
0xca, 0x66, 0xe0, 0xbd, 0xc6, 0x34, 0xd8, 0x84,
0x09, 0x9a, 0x58, 0x14, 0xfb, 0x05, 0xda, 0xf5,
0x90, 0xf5, 0x0c, 0x4e, 0x22, 0x10, 0xc9, 0x85,
0x0f, 0xe3, 0x77, 0x35, 0xe9, 0x6b, 0xc2, 0x55,
0x32, 0x46, 0xae, 0x25, 0xe0, 0xe3, 0x37, 0x7a,
0x4b, 0x71, 0xcc, 0xfc, 0x91, 0xdf, 0xd6, 0xca,
0xfe, 0xee, 0xce, 0x3f, 0x77, 0xa2, 0xfd, 0x59,
0x8e, 0x73, 0x0a, 0x8d, 0x5c, 0x24, 0x14, 0xca,
0x38, 0x91, 0xb8, 0x2c, 0x8c, 0xa2, 0x65, 0x7b,
0xbc, 0x49, 0xbc, 0xb5, 0x58, 0xfc, 0xe3, 0xd7,
0x02, 0xcf, 0xf7, 0x4c, 0x60, 0x91, 0xed, 0x55,
0xe9, 0xf9, 0xfe, 0xd1, 0x44, 0x2c, 0x75, 0xf2,
0xb3, 0x5d, 0x7b, 0x27, 0x56, 0xc0, 0x48, 0x4f,
0xb0, 0xba, 0xe4, 0x7d, 0xd0, 0xaa, 0xcd, 0x3d,
0xe3, 0x50, 0xd2, 0xcf, 0xb9, 0xfa, 0x4b, 0x2d,
0xc6, 0xdf, 0x3b, 0x32, 0x98, 0x45, 0xe6, 0x8f,
0x1c, 0x5c, 0xa2, 0x20, 0x7d, 0x1c, 0x28, 0xc2,
0xd4, 0xa1, 0xe0, 0x21, 0x52, 0x8f, 0x1c, 0xd0,
0x62, 0x97, 0x48, 0xbb, 0xf4, 0xa9, 0xcb, 0x35,
0xf2, 0x07, 0xd3, 0x50, 0xd8, 0xa9, 0xc5, 0x9a,
0x0f, 0xbd, 0x37, 0xaf, 0xe1, 0x45, 0x19, 0xee,
0x41, 0xf3, 0xf7, 0xe5, 0xe0, 0x30, 0x3f, 0xbe,
0x3d, 0x39, 0x64, 0x00, 0x7a, 0x1a, 0x51, 0x5e,
0xe1, 0x70, 0x0b, 0xb9, 0x77, 0x5a, 0xf0, 0xc4,
0x8a, 0xa1, 0x3a, 0x77, 0x1a, 0xe0, 0xc2, 0x06,
0x91, 0xd5, 0xe9, 0x1c, 0xd3, 0xfe, 0xab, 0x93,
0x1a, 0x0a, 0x4c, 0xbb, 0xf0, 0xff, 0xdc, 0xaa,
0x61, 0x73, 0xcb, 0x03, 0x4b, 0x71, 0x68, 0x64,
0x3d, 0x82, 0x31, 0x41, 0xd7, 0x8b, 0x22, 0x7b,
0x7d, 0xa1, 0xd5, 0x85, 0x6d, 0xf0, 0x1b, 0xaa,
})
}

396
device/device.go Normal file
View File

@@ -0,0 +1,396 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/tun"
"runtime"
"sync"
"sync/atomic"
"time"
)
const (
DeviceRoutineNumberPerCPU = 3
DeviceRoutineNumberAdditional = 2
)
type Device struct {
isUp AtomicBool // device is (going) up
isClosed AtomicBool // device is closed? (acting as guard)
log *Logger
// synchronized resources (locks acquired in order)
state struct {
starting sync.WaitGroup
stopping sync.WaitGroup
sync.Mutex
changing AtomicBool
current bool
}
net struct {
starting sync.WaitGroup
stopping sync.WaitGroup
sync.RWMutex
bind Bind // bind interface
port uint16 // listening port
fwmark uint32 // mark value (0 = disabled)
}
staticIdentity struct {
sync.RWMutex
privateKey NoisePrivateKey
publicKey NoisePublicKey
}
peers struct {
sync.RWMutex
keyMap map[NoisePublicKey]*Peer
}
// unprotected / "self-synchronising resources"
allowedips AllowedIPs
indexTable IndexTable
cookieChecker CookieChecker
rate struct {
underLoadUntil atomic.Value
limiter ratelimiter.Ratelimiter
}
pool struct {
messageBufferPool *sync.Pool
messageBufferReuseChan chan *[MaxMessageSize]byte
inboundElementPool *sync.Pool
inboundElementReuseChan chan *QueueInboundElement
outboundElementPool *sync.Pool
outboundElementReuseChan chan *QueueOutboundElement
}
queue struct {
encryption chan *QueueOutboundElement
decryption chan *QueueInboundElement
handshake chan QueueHandshakeElement
}
signals struct {
stop chan struct{}
}
tun struct {
device tun.TUNDevice
mtu int32
}
}
/* Converts the peer into a "zombie", which remains in the peer map,
* but processes no packets and does not exists in the routing table.
*
* Must hold device.peers.Mutex
*/
func unsafeRemovePeer(device *Device, peer *Peer, key NoisePublicKey) {
// stop routing and processing of packets
device.allowedips.RemoveByPeer(peer)
peer.Stop()
// remove from peer map
delete(device.peers.keyMap, key)
}
func deviceUpdateState(device *Device) {
// check if state already being updated (guard)
if device.state.changing.Swap(true) {
return
}
// compare to current state of device
device.state.Lock()
newIsUp := device.isUp.Get()
if newIsUp == device.state.current {
device.state.changing.Set(false)
device.state.Unlock()
return
}
// change state of device
switch newIsUp {
case true:
if err := device.BindUpdate(); err != nil {
device.isUp.Set(false)
break
}
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Start()
if peer.persistentKeepaliveInterval > 0 {
peer.SendKeepalive()
}
}
device.peers.RUnlock()
case false:
device.BindClose()
device.peers.RLock()
for _, peer := range device.peers.keyMap {
peer.Stop()
}
device.peers.RUnlock()
}
// update state variables
device.state.current = newIsUp
device.state.changing.Set(false)
device.state.Unlock()
// check for state change in the mean time
deviceUpdateState(device)
}
func (device *Device) Up() {
// closed device cannot be brought up
if device.isClosed.Get() {
return
}
device.isUp.Set(true)
deviceUpdateState(device)
}
func (device *Device) Down() {
device.isUp.Set(false)
deviceUpdateState(device)
}
func (device *Device) IsUnderLoad() bool {
// check if currently under load
now := time.Now()
underLoad := len(device.queue.handshake) >= UnderLoadQueueSize
if underLoad {
device.rate.underLoadUntil.Store(now.Add(UnderLoadAfterTime))
return true
}
// check if recently under load
until := device.rate.underLoadUntil.Load().(time.Time)
return until.After(now)
}
func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
// lock required resources
device.staticIdentity.Lock()
defer device.staticIdentity.Unlock()
device.peers.Lock()
defer device.peers.Unlock()
for _, peer := range device.peers.keyMap {
peer.handshake.mutex.RLock()
defer peer.handshake.mutex.RUnlock()
}
// remove peers with matching public keys
publicKey := sk.publicKey()
for key, peer := range device.peers.keyMap {
if peer.handshake.remoteStatic.Equals(publicKey) {
unsafeRemovePeer(device, peer, key)
}
}
// update key material
device.staticIdentity.privateKey = sk
device.staticIdentity.publicKey = publicKey
device.cookieChecker.Init(publicKey)
// do static-static DH pre-computations
rmKey := device.staticIdentity.privateKey.IsZero()
for key, peer := range device.peers.keyMap {
handshake := &peer.handshake
if rmKey {
handshake.precomputedStaticStatic = [NoisePublicKeySize]byte{}
} else {
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(handshake.remoteStatic)
}
if isZero(handshake.precomputedStaticStatic[:]) {
unsafeRemovePeer(device, peer, key)
}
}
return nil
}
func NewDevice(tunDevice tun.TUNDevice, logger *Logger) *Device {
device := new(Device)
device.isUp.Set(false)
device.isClosed.Set(false)
device.log = logger
device.tun.device = tunDevice
mtu, err := device.tun.device.MTU()
if err != nil {
logger.Error.Println("Trouble determining MTU, assuming default:", err)
mtu = DefaultMTU
}
device.tun.mtu = int32(mtu)
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
device.rate.limiter.Init()
device.rate.underLoadUntil.Store(time.Time{})
device.indexTable.Init()
device.allowedips.Reset()
device.PopulatePools()
// create queues
device.queue.handshake = make(chan QueueHandshakeElement, QueueHandshakeSize)
device.queue.encryption = make(chan *QueueOutboundElement, QueueOutboundSize)
device.queue.decryption = make(chan *QueueInboundElement, QueueInboundSize)
// prepare signals
device.signals.stop = make(chan struct{})
// prepare net
device.net.port = 0
device.net.bind = nil
// start workers
cpus := runtime.NumCPU()
device.state.starting.Wait()
device.state.stopping.Wait()
device.state.stopping.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
device.state.starting.Add(DeviceRoutineNumberPerCPU*cpus + DeviceRoutineNumberAdditional)
for i := 0; i < cpus; i += 1 {
go device.RoutineEncryption()
go device.RoutineDecryption()
go device.RoutineHandshake()
}
go device.RoutineReadFromTUN()
go device.RoutineTUNEventReader()
device.state.starting.Wait()
return device
}
func (device *Device) LookupPeer(pk NoisePublicKey) *Peer {
device.peers.RLock()
defer device.peers.RUnlock()
return device.peers.keyMap[pk]
}
func (device *Device) RemovePeer(key NoisePublicKey) {
device.peers.Lock()
defer device.peers.Unlock()
// stop peer and remove from routing
peer, ok := device.peers.keyMap[key]
if ok {
unsafeRemovePeer(device, peer, key)
}
}
func (device *Device) RemoveAllPeers() {
device.peers.Lock()
defer device.peers.Unlock()
for key, peer := range device.peers.keyMap {
unsafeRemovePeer(device, peer, key)
}
device.peers.keyMap = make(map[NoisePublicKey]*Peer)
}
func (device *Device) FlushPacketQueues() {
for {
select {
case elem, ok := <-device.queue.decryption:
if ok {
elem.Drop()
}
case elem, ok := <-device.queue.encryption:
if ok {
elem.Drop()
}
case <-device.queue.handshake:
default:
return
}
}
}
func (device *Device) Close() {
if device.isClosed.Swap(true) {
return
}
device.state.starting.Wait()
device.log.Info.Println("Device closing")
device.state.changing.Set(true)
device.state.Lock()
defer device.state.Unlock()
device.tun.device.Close()
device.BindClose()
device.isUp.Set(false)
close(device.signals.stop)
device.RemoveAllPeers()
device.state.stopping.Wait()
device.FlushPacketQueues()
device.rate.limiter.Close()
device.state.changing.Set(false)
device.log.Info.Println("Interface closed")
}
func (device *Device) Wait() chan struct{} {
return device.signals.stop
}

48
device/device_test.go Normal file
View File

@@ -0,0 +1,48 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
/* Create two device instances and simulate full WireGuard interaction
* without network dependencies
*/
import "testing"
func TestDevice(t *testing.T) {
// prepare tun devices for generating traffic
tun1, err := CreateDummyTUN("tun1")
if err != nil {
t.Error("failed to create tun:", err.Error())
}
tun2, err := CreateDummyTUN("tun2")
if err != nil {
t.Error("failed to create tun:", err.Error())
}
_ = tun1
_ = tun2
// prepare endpoints
end1, err := CreateDummyEndpoint()
if err != nil {
t.Error("failed to create endpoint:", err.Error())
}
end2, err := CreateDummyEndpoint()
if err != nil {
t.Error("failed to create endpoint:", err.Error())
}
_ = end1
_ = end2
// create binds
}

53
device/endpoint_test.go Normal file
View File

@@ -0,0 +1,53 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"math/rand"
"net"
)
type DummyEndpoint struct {
src [16]byte
dst [16]byte
}
func CreateDummyEndpoint() (*DummyEndpoint, error) {
var end DummyEndpoint
if _, err := rand.Read(end.src[:]); err != nil {
return nil, err
}
_, err := rand.Read(end.dst[:])
return &end, err
}
func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string {
var addr net.UDPAddr
addr.IP = e.SrcIP()
addr.Port = 1000
return addr.String()
}
func (e *DummyEndpoint) DstToString() string {
var addr net.UDPAddr
addr.IP = e.DstIP()
addr.Port = 1000
return addr.String()
}
func (e *DummyEndpoint) SrcToBytes() []byte {
return e.src[:]
}
func (e *DummyEndpoint) DstIP() net.IP {
return e.dst[:]
}
func (e *DummyEndpoint) SrcIP() net.IP {
return e.src[:]
}

97
device/indextable.go Normal file
View File

@@ -0,0 +1,97 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"crypto/rand"
"sync"
"unsafe"
)
type IndexTableEntry struct {
peer *Peer
handshake *Handshake
keypair *Keypair
}
type IndexTable struct {
sync.RWMutex
table map[uint32]IndexTableEntry
}
func randUint32() (uint32, error) {
var integer [4]byte
_, err := rand.Read(integer[:])
return *(*uint32)(unsafe.Pointer(&integer[0])), err
}
func (table *IndexTable) Init() {
table.Lock()
defer table.Unlock()
table.table = make(map[uint32]IndexTableEntry)
}
func (table *IndexTable) Delete(index uint32) {
table.Lock()
defer table.Unlock()
delete(table.table, index)
}
func (table *IndexTable) SwapIndexForKeypair(index uint32, keypair *Keypair) {
table.Lock()
defer table.Unlock()
entry, ok := table.table[index]
if !ok {
return
}
table.table[index] = IndexTableEntry{
peer: entry.peer,
keypair: keypair,
handshake: nil,
}
}
func (table *IndexTable) NewIndexForHandshake(peer *Peer, handshake *Handshake) (uint32, error) {
for {
// generate random index
index, err := randUint32()
if err != nil {
return index, err
}
// check if index used
table.RLock()
_, ok := table.table[index]
table.RUnlock()
if ok {
continue
}
// check again while locked
table.Lock()
_, found := table.table[index]
if found {
table.Unlock()
continue
}
table.table[index] = IndexTableEntry{
peer: peer,
handshake: handshake,
keypair: nil,
}
table.Unlock()
return index, nil
}
}
func (table *IndexTable) Lookup(id uint32) IndexTableEntry {
table.RLock()
defer table.RUnlock()
return table.table[id]
}

22
device/ip.go Normal file
View File

@@ -0,0 +1,22 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"net"
)
const (
IPv4offsetTotalLength = 2
IPv4offsetSrc = 12
IPv4offsetDst = IPv4offsetSrc + net.IPv4len
)
const (
IPv6offsetPayloadLength = 4
IPv6offsetSrc = 8
IPv6offsetDst = IPv6offsetSrc + net.IPv6len
)

84
device/kdf_test.go Normal file
View File

@@ -0,0 +1,84 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"encoding/hex"
"golang.org/x/crypto/blake2s"
"testing"
)
type KDFTest struct {
key string
input string
t0 string
t1 string
t2 string
}
func assertEquals(t *testing.T, a string, b string) {
if a != b {
t.Fatal("expected", a, "=", b)
}
}
func TestKDF(t *testing.T) {
tests := []KDFTest{
{
key: "746573742d6b6579",
input: "746573742d696e707574",
t0: "6f0e5ad38daba1bea8a0d213688736f19763239305e0f58aba697f9ffc41c633",
t1: "df1194df20802a4fe594cde27e92991c8cae66c366e8106aaa937a55fa371e8a",
t2: "fac6e2745a325f5dc5d11a5b165aad08b0ada28e7b4e666b7c077934a4d76c24",
},
{
key: "776972656775617264",
input: "776972656775617264",
t0: "491d43bbfdaa8750aaf535e334ecbfe5129967cd64635101c566d4caefda96e8",
t1: "1e71a379baefd8a79aa4662212fcafe19a23e2b609a3db7d6bcba8f560e3d25f",
t2: "31e1ae48bddfbe5de38f295e5452b1909a1b4e38e183926af3780b0c1e1f0160",
},
{
key: "",
input: "",
t0: "8387b46bf43eccfcf349552a095d8315c4055beb90208fb1be23b894bc2ed5d0",
t1: "58a0e5f6faefccf4807bff1f05fa8a9217945762040bcec2f4b4a62bdfe0e86e",
t2: "0ce6ea98ec548f8e281e93e32db65621c45eb18dc6f0a7ad94178610a2f7338e",
},
}
var t0, t1, t2 [blake2s.Size]byte
for _, test := range tests {
key, _ := hex.DecodeString(test.key)
input, _ := hex.DecodeString(test.input)
KDF3(&t0, &t1, &t2, key, input)
t0s := hex.EncodeToString(t0[:])
t1s := hex.EncodeToString(t1[:])
t2s := hex.EncodeToString(t2[:])
assertEquals(t, t0s, test.t0)
assertEquals(t, t1s, test.t1)
assertEquals(t, t2s, test.t2)
}
for _, test := range tests {
key, _ := hex.DecodeString(test.key)
input, _ := hex.DecodeString(test.input)
KDF2(&t0, &t1, key, input)
t0s := hex.EncodeToString(t0[:])
t1s := hex.EncodeToString(t1[:])
assertEquals(t, t0s, test.t0)
assertEquals(t, t1s, test.t1)
}
for _, test := range tests {
key, _ := hex.DecodeString(test.key)
input, _ := hex.DecodeString(test.input)
KDF1(&t0, key, input)
t0s := hex.EncodeToString(t0[:])
assertEquals(t, t0s, test.t0)
}
}

50
device/keypair.go Normal file
View File

@@ -0,0 +1,50 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"crypto/cipher"
"golang.zx2c4.com/wireguard/replay"
"sync"
"time"
)
/* Due to limitations in Go and /x/crypto there is currently
* no way to ensure that key material is securely ereased in memory.
*
* Since this may harm the forward secrecy property,
* we plan to resolve this issue; whenever Go allows us to do so.
*/
type Keypair struct {
sendNonce uint64
send cipher.AEAD
receive cipher.AEAD
replayFilter replay.ReplayFilter
isInitiator bool
created time.Time
localIndex uint32
remoteIndex uint32
}
type Keypairs struct {
sync.RWMutex
current *Keypair
previous *Keypair
next *Keypair
}
func (kp *Keypairs) Current() *Keypair {
kp.RLock()
defer kp.RUnlock()
return kp.current
}
func (device *Device) DeleteKeypair(key *Keypair) {
if key != nil {
device.indexTable.Delete(key.localIndex)
}
}

59
device/logger.go Normal file
View File

@@ -0,0 +1,59 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"io"
"io/ioutil"
"log"
"os"
)
const (
LogLevelSilent = iota
LogLevelError
LogLevelInfo
LogLevelDebug
)
type Logger struct {
Debug *log.Logger
Info *log.Logger
Error *log.Logger
}
func NewLogger(level int, prepend string) *Logger {
output := os.Stdout
logger := new(Logger)
logErr, logInfo, logDebug := func() (io.Writer, io.Writer, io.Writer) {
if level >= LogLevelDebug {
return output, output, output
}
if level >= LogLevelInfo {
return output, output, ioutil.Discard
}
if level >= LogLevelError {
return output, ioutil.Discard, ioutil.Discard
}
return ioutil.Discard, ioutil.Discard, ioutil.Discard
}()
logger.Debug = log.New(logDebug,
"DEBUG: "+prepend,
log.Ldate|log.Ltime,
)
logger.Info = log.New(logInfo,
"INFO: "+prepend,
log.Ldate|log.Ltime,
)
logger.Error = log.New(logErr,
"ERROR: "+prepend,
log.Ldate|log.Ltime,
)
return logger
}

12
device/mark_default.go Normal file
View File

@@ -0,0 +1,12 @@
// +build !linux,!openbsd,!freebsd
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
func (bind *NativeBind) SetMark(mark uint32) error {
return nil
}

64
device/mark_unix.go Normal file
View File

@@ -0,0 +1,64 @@
// +build android openbsd freebsd
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"golang.org/x/sys/unix"
"runtime"
)
var fwmarkIoctl int
func init() {
switch runtime.GOOS {
case "linux", "android":
fwmarkIoctl = 36 /* unix.SO_MARK */
case "freebsd":
fwmarkIoctl = 0x1015 /* unix.SO_USER_COOKIE */
case "openbsd":
fwmarkIoctl = 0x1021 /* unix.SO_RTABLE */
}
}
func (bind *NativeBind) SetMark(mark uint32) error {
var operr error
if fwmarkIoctl == 0 {
return nil
}
if bind.ipv4 != nil {
fd, err := bind.ipv4.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err == nil {
err = operr
}
if err != nil {
return err
}
}
if bind.ipv6 != nil {
fd, err := bind.ipv6.SyscallConn()
if err != nil {
return err
}
err = fd.Control(func(fd uintptr) {
operr = unix.SetsockoptInt(int(fd), unix.SOL_SOCKET, fwmarkIoctl, int(mark))
})
if err == nil {
err = operr
}
if err != nil {
return err
}
}
return nil
}

48
device/misc.go Normal file
View File

@@ -0,0 +1,48 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"sync/atomic"
)
/* Atomic Boolean */
const (
AtomicFalse = int32(iota)
AtomicTrue
)
type AtomicBool struct {
int32
}
func (a *AtomicBool) Get() bool {
return atomic.LoadInt32(&a.int32) == AtomicTrue
}
func (a *AtomicBool) Swap(val bool) bool {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
return atomic.SwapInt32(&a.int32, flag) == AtomicTrue
}
func (a *AtomicBool) Set(val bool) {
flag := AtomicFalse
if val {
flag = AtomicTrue
}
atomic.StoreInt32(&a.int32, flag)
}
func min(a, b uint) uint {
if a > b {
return b
}
return a
}

104
device/noise-helpers.go Normal file
View File

@@ -0,0 +1,104 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"crypto/hmac"
"crypto/rand"
"crypto/subtle"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/curve25519"
"hash"
)
/* KDF related functions.
* HMAC-based Key Derivation Function (HKDF)
* https://tools.ietf.org/html/rfc5869
*/
func HMAC1(sum *[blake2s.Size]byte, key, in0 []byte) {
mac := hmac.New(func() hash.Hash {
h, _ := blake2s.New256(nil)
return h
}, key)
mac.Write(in0)
mac.Sum(sum[:0])
}
func HMAC2(sum *[blake2s.Size]byte, key, in0, in1 []byte) {
mac := hmac.New(func() hash.Hash {
h, _ := blake2s.New256(nil)
return h
}, key)
mac.Write(in0)
mac.Write(in1)
mac.Sum(sum[:0])
}
func KDF1(t0 *[blake2s.Size]byte, key, input []byte) {
HMAC1(t0, key, input)
HMAC1(t0, t0[:], []byte{0x1})
return
}
func KDF2(t0, t1 *[blake2s.Size]byte, key, input []byte) {
var prk [blake2s.Size]byte
HMAC1(&prk, key, input)
HMAC1(t0, prk[:], []byte{0x1})
HMAC2(t1, prk[:], t0[:], []byte{0x2})
setZero(prk[:])
return
}
func KDF3(t0, t1, t2 *[blake2s.Size]byte, key, input []byte) {
var prk [blake2s.Size]byte
HMAC1(&prk, key, input)
HMAC1(t0, prk[:], []byte{0x1})
HMAC2(t1, prk[:], t0[:], []byte{0x2})
HMAC2(t2, prk[:], t1[:], []byte{0x3})
setZero(prk[:])
return
}
func isZero(val []byte) bool {
acc := 1
for _, b := range val {
acc &= subtle.ConstantTimeByteEq(b, 0)
}
return acc == 1
}
/* This function is not used as pervasively as it should because this is mostly impossible in Go at the moment */
func setZero(arr []byte) {
for i := range arr {
arr[i] = 0
}
}
func (sk *NoisePrivateKey) clamp() {
sk[0] &= 248
sk[31] = (sk[31] & 127) | 64
}
func newPrivateKey() (sk NoisePrivateKey, err error) {
_, err = rand.Read(sk[:])
sk.clamp()
return
}
func (sk *NoisePrivateKey) publicKey() (pk NoisePublicKey) {
apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarBaseMult(apk, ask)
return
}
func (sk *NoisePrivateKey) sharedSecret(pk NoisePublicKey) (ss [NoisePublicKeySize]byte) {
apk := (*[NoisePublicKeySize]byte)(&pk)
ask := (*[NoisePrivateKeySize]byte)(sk)
curve25519.ScalarMult(&ss, ask, apk)
return ss
}

600
device/noise-protocol.go Normal file
View File

@@ -0,0 +1,600 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"errors"
"golang.org/x/crypto/blake2s"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/crypto/poly1305"
"golang.zx2c4.com/wireguard/tai64n"
"sync"
"time"
)
const (
HandshakeZeroed = iota
HandshakeInitiationCreated
HandshakeInitiationConsumed
HandshakeResponseCreated
HandshakeResponseConsumed
)
const (
NoiseConstruction = "Noise_IKpsk2_25519_ChaChaPoly_BLAKE2s"
WGIdentifier = "WireGuard v1 zx2c4 Jason@zx2c4.com"
WGLabelMAC1 = "mac1----"
WGLabelCookie = "cookie--"
)
const (
MessageInitiationType = 1
MessageResponseType = 2
MessageCookieReplyType = 3
MessageTransportType = 4
)
const (
MessageInitiationSize = 148 // size of handshake initation message
MessageResponseSize = 92 // size of response message
MessageCookieReplySize = 64 // size of cookie reply message
MessageTransportHeaderSize = 16 // size of data preceeding content in transport message
MessageTransportSize = MessageTransportHeaderSize + poly1305.TagSize // size of empty transport
MessageKeepaliveSize = MessageTransportSize // size of keepalive
MessageHandshakeSize = MessageInitiationSize // size of largest handshake releated message
)
const (
MessageTransportOffsetReceiver = 4
MessageTransportOffsetCounter = 8
MessageTransportOffsetContent = 16
)
/* Type is an 8-bit field, followed by 3 nul bytes,
* by marshalling the messages in little-endian byteorder
* we can treat these as a 32-bit unsigned int (for now)
*
*/
type MessageInitiation struct {
Type uint32
Sender uint32
Ephemeral NoisePublicKey
Static [NoisePublicKeySize + poly1305.TagSize]byte
Timestamp [tai64n.TimestampSize + poly1305.TagSize]byte
MAC1 [blake2s.Size128]byte
MAC2 [blake2s.Size128]byte
}
type MessageResponse struct {
Type uint32
Sender uint32
Receiver uint32
Ephemeral NoisePublicKey
Empty [poly1305.TagSize]byte
MAC1 [blake2s.Size128]byte
MAC2 [blake2s.Size128]byte
}
type MessageTransport struct {
Type uint32
Receiver uint32
Counter uint64
Content []byte
}
type MessageCookieReply struct {
Type uint32
Receiver uint32
Nonce [chacha20poly1305.NonceSizeX]byte
Cookie [blake2s.Size128 + poly1305.TagSize]byte
}
type Handshake struct {
state int
mutex sync.RWMutex
hash [blake2s.Size]byte // hash value
chainKey [blake2s.Size]byte // chain key
presharedKey NoiseSymmetricKey // psk
localEphemeral NoisePrivateKey // ephemeral secret key
localIndex uint32 // used to clear hash-table
remoteIndex uint32 // index for sending
remoteStatic NoisePublicKey // long term key
remoteEphemeral NoisePublicKey // ephemeral public key
precomputedStaticStatic [NoisePublicKeySize]byte // precomputed shared secret
lastTimestamp tai64n.Timestamp
lastInitiationConsumption time.Time
lastSentHandshake time.Time
}
var (
InitialChainKey [blake2s.Size]byte
InitialHash [blake2s.Size]byte
ZeroNonce [chacha20poly1305.NonceSize]byte
)
func mixKey(dst *[blake2s.Size]byte, c *[blake2s.Size]byte, data []byte) {
KDF1(dst, c[:], data)
}
func mixHash(dst *[blake2s.Size]byte, h *[blake2s.Size]byte, data []byte) {
hash, _ := blake2s.New256(nil)
hash.Write(h[:])
hash.Write(data)
hash.Sum(dst[:0])
hash.Reset()
}
func (h *Handshake) Clear() {
setZero(h.localEphemeral[:])
setZero(h.remoteEphemeral[:])
setZero(h.chainKey[:])
setZero(h.hash[:])
h.localIndex = 0
h.state = HandshakeZeroed
}
func (h *Handshake) mixHash(data []byte) {
mixHash(&h.hash, &h.hash, data)
}
func (h *Handshake) mixKey(data []byte) {
mixKey(&h.chainKey, &h.chainKey, data)
}
/* Do basic precomputations
*/
func init() {
InitialChainKey = blake2s.Sum256([]byte(NoiseConstruction))
mixHash(&InitialHash, &InitialChainKey, []byte(WGIdentifier))
}
func (device *Device) CreateMessageInitiation(peer *Peer) (*MessageInitiation, error) {
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
if isZero(handshake.precomputedStaticStatic[:]) {
return nil, errors.New("static shared secret is zero")
}
// create ephemeral key
var err error
handshake.hash = InitialHash
handshake.chainKey = InitialChainKey
handshake.localEphemeral, err = newPrivateKey()
if err != nil {
return nil, err
}
// assign index
device.indexTable.Delete(handshake.localIndex)
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
handshake.mixHash(handshake.remoteStatic[:])
msg := MessageInitiation{
Type: MessageInitiationType,
Ephemeral: handshake.localEphemeral.publicKey(),
Sender: handshake.localIndex,
}
handshake.mixKey(msg.Ephemeral[:])
handshake.mixHash(msg.Ephemeral[:])
// encrypt static key
func() {
var key [chacha20poly1305.KeySize]byte
ss := handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
KDF2(
&handshake.chainKey,
&key,
handshake.chainKey[:],
ss[:],
)
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Static[:0], ZeroNonce[:], device.staticIdentity.publicKey[:], handshake.hash[:])
}()
handshake.mixHash(msg.Static[:])
// encrypt timestamp
timestamp := tai64n.Now()
func() {
var key [chacha20poly1305.KeySize]byte
KDF2(
&handshake.chainKey,
&key,
handshake.chainKey[:],
handshake.precomputedStaticStatic[:],
)
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Timestamp[:0], ZeroNonce[:], timestamp[:], handshake.hash[:])
}()
handshake.mixHash(msg.Timestamp[:])
handshake.state = HandshakeInitiationCreated
return &msg, nil
}
func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation) *Peer {
var (
hash [blake2s.Size]byte
chainKey [blake2s.Size]byte
)
if msg.Type != MessageInitiationType {
return nil
}
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:])
mixHash(&hash, &hash, msg.Ephemeral[:])
mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:])
// decrypt static key
var err error
var peerPK NoisePublicKey
func() {
var key [chacha20poly1305.KeySize]byte
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
KDF2(&chainKey, &key, chainKey[:], ss[:])
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(peerPK[:0], ZeroNonce[:], msg.Static[:], hash[:])
}()
if err != nil {
return nil
}
mixHash(&hash, &hash, msg.Static[:])
// lookup peer
peer := device.LookupPeer(peerPK)
if peer == nil {
return nil
}
handshake := &peer.handshake
if isZero(handshake.precomputedStaticStatic[:]) {
return nil
}
// verify identity
var timestamp tai64n.Timestamp
var key [chacha20poly1305.KeySize]byte
handshake.mutex.RLock()
KDF2(
&chainKey,
&key,
chainKey[:],
handshake.precomputedStaticStatic[:],
)
aead, _ := chacha20poly1305.New(key[:])
_, err = aead.Open(timestamp[:0], ZeroNonce[:], msg.Timestamp[:], hash[:])
if err != nil {
handshake.mutex.RUnlock()
return nil
}
mixHash(&hash, &hash, msg.Timestamp[:])
// protect against replay & flood
var ok bool
ok = timestamp.After(handshake.lastTimestamp)
ok = ok && time.Now().Sub(handshake.lastInitiationConsumption) > HandshakeInitationRate
handshake.mutex.RUnlock()
if !ok {
return nil
}
// update handshake state
handshake.mutex.Lock()
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.remoteEphemeral = msg.Ephemeral
handshake.lastTimestamp = timestamp
handshake.lastInitiationConsumption = time.Now()
handshake.state = HandshakeInitiationConsumed
handshake.mutex.Unlock()
setZero(hash[:])
setZero(chainKey[:])
return peer
}
func (device *Device) CreateMessageResponse(peer *Peer) (*MessageResponse, error) {
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
if handshake.state != HandshakeInitiationConsumed {
return nil, errors.New("handshake initiation must be consumed first")
}
// assign index
var err error
device.indexTable.Delete(handshake.localIndex)
handshake.localIndex, err = device.indexTable.NewIndexForHandshake(peer, handshake)
if err != nil {
return nil, err
}
var msg MessageResponse
msg.Type = MessageResponseType
msg.Sender = handshake.localIndex
msg.Receiver = handshake.remoteIndex
// create ephemeral key
handshake.localEphemeral, err = newPrivateKey()
if err != nil {
return nil, err
}
msg.Ephemeral = handshake.localEphemeral.publicKey()
handshake.mixHash(msg.Ephemeral[:])
handshake.mixKey(msg.Ephemeral[:])
func() {
ss := handshake.localEphemeral.sharedSecret(handshake.remoteEphemeral)
handshake.mixKey(ss[:])
ss = handshake.localEphemeral.sharedSecret(handshake.remoteStatic)
handshake.mixKey(ss[:])
}()
// add preshared key
var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte
KDF3(
&handshake.chainKey,
&tau,
&key,
handshake.chainKey[:],
handshake.presharedKey[:],
)
handshake.mixHash(tau[:])
func() {
aead, _ := chacha20poly1305.New(key[:])
aead.Seal(msg.Empty[:0], ZeroNonce[:], nil, handshake.hash[:])
handshake.mixHash(msg.Empty[:])
}()
handshake.state = HandshakeResponseCreated
return &msg, nil
}
func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer {
if msg.Type != MessageResponseType {
return nil
}
// lookup handshake by receiver
lookup := device.indexTable.Lookup(msg.Receiver)
handshake := lookup.handshake
if handshake == nil {
return nil
}
var (
hash [blake2s.Size]byte
chainKey [blake2s.Size]byte
)
ok := func() bool {
// lock handshake state
handshake.mutex.RLock()
defer handshake.mutex.RUnlock()
if handshake.state != HandshakeInitiationCreated {
return false
}
// lock private key for reading
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
// finish 3-way DH
mixHash(&hash, &handshake.hash, msg.Ephemeral[:])
mixKey(&chainKey, &handshake.chainKey, msg.Ephemeral[:])
func() {
ss := handshake.localEphemeral.sharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
}()
func() {
ss := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral)
mixKey(&chainKey, &chainKey, ss[:])
setZero(ss[:])
}()
// add preshared key (psk)
var tau [blake2s.Size]byte
var key [chacha20poly1305.KeySize]byte
KDF3(
&chainKey,
&tau,
&key,
chainKey[:],
handshake.presharedKey[:],
)
mixHash(&hash, &hash, tau[:])
// authenticate transcript
aead, _ := chacha20poly1305.New(key[:])
_, err := aead.Open(nil, ZeroNonce[:], msg.Empty[:], hash[:])
if err != nil {
return false
}
mixHash(&hash, &hash, msg.Empty[:])
return true
}()
if !ok {
return nil
}
// update handshake state
handshake.mutex.Lock()
handshake.hash = hash
handshake.chainKey = chainKey
handshake.remoteIndex = msg.Sender
handshake.state = HandshakeResponseConsumed
handshake.mutex.Unlock()
setZero(hash[:])
setZero(chainKey[:])
return lookup.peer
}
/* Derives a new keypair from the current handshake state
*
*/
func (peer *Peer) BeginSymmetricSession() error {
device := peer.device
handshake := &peer.handshake
handshake.mutex.Lock()
defer handshake.mutex.Unlock()
// derive keys
var isInitiator bool
var sendKey [chacha20poly1305.KeySize]byte
var recvKey [chacha20poly1305.KeySize]byte
if handshake.state == HandshakeResponseConsumed {
KDF2(
&sendKey,
&recvKey,
handshake.chainKey[:],
nil,
)
isInitiator = true
} else if handshake.state == HandshakeResponseCreated {
KDF2(
&recvKey,
&sendKey,
handshake.chainKey[:],
nil,
)
isInitiator = false
} else {
return errors.New("invalid state for keypair derivation")
}
// zero handshake
setZero(handshake.chainKey[:])
setZero(handshake.hash[:]) // Doesn't necessarily need to be zeroed. Could be used for something interesting down the line.
setZero(handshake.localEphemeral[:])
peer.handshake.state = HandshakeZeroed
// create AEAD instances
keypair := new(Keypair)
keypair.send, _ = chacha20poly1305.New(sendKey[:])
keypair.receive, _ = chacha20poly1305.New(recvKey[:])
setZero(sendKey[:])
setZero(recvKey[:])
keypair.created = time.Now()
keypair.sendNonce = 0
keypair.replayFilter.Init()
keypair.isInitiator = isInitiator
keypair.localIndex = peer.handshake.localIndex
keypair.remoteIndex = peer.handshake.remoteIndex
// remap index
device.indexTable.SwapIndexForKeypair(handshake.localIndex, keypair)
handshake.localIndex = 0
// rotate key pairs
keypairs := &peer.keypairs
keypairs.Lock()
defer keypairs.Unlock()
previous := keypairs.previous
next := keypairs.next
current := keypairs.current
if isInitiator {
if next != nil {
keypairs.next = nil
keypairs.previous = next
device.DeleteKeypair(current)
} else {
keypairs.previous = current
}
device.DeleteKeypair(previous)
keypairs.current = keypair
} else {
keypairs.next = keypair
device.DeleteKeypair(next)
keypairs.previous = nil
device.DeleteKeypair(previous)
}
return nil
}
func (peer *Peer) ReceivedWithKeypair(receivedKeypair *Keypair) bool {
keypairs := &peer.keypairs
if keypairs.next != receivedKeypair {
return false
}
keypairs.Lock()
defer keypairs.Unlock()
if keypairs.next != receivedKeypair {
return false
}
old := keypairs.previous
keypairs.previous = keypairs.current
peer.device.DeleteKeypair(old)
keypairs.current = keypairs.next
keypairs.next = nil
return true
}

81
device/noise-types.go Normal file
View File

@@ -0,0 +1,81 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"crypto/subtle"
"encoding/hex"
"errors"
"golang.org/x/crypto/chacha20poly1305"
)
const (
NoisePublicKeySize = 32
NoisePrivateKeySize = 32
)
type (
NoisePublicKey [NoisePublicKeySize]byte
NoisePrivateKey [NoisePrivateKeySize]byte
NoiseSymmetricKey [chacha20poly1305.KeySize]byte
NoiseNonce uint64 // padded to 12-bytes
)
func loadExactHex(dst []byte, src string) error {
slice, err := hex.DecodeString(src)
if err != nil {
return err
}
if len(slice) != len(dst) {
return errors.New("hex string does not fit the slice")
}
copy(dst, slice)
return nil
}
func (key NoisePrivateKey) IsZero() bool {
var zero NoisePrivateKey
return key.Equals(zero)
}
func (key NoisePrivateKey) Equals(tar NoisePrivateKey) bool {
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
}
func (key *NoisePrivateKey) FromHex(src string) (err error) {
err = loadExactHex(key[:], src)
key.clamp()
return
}
func (key NoisePrivateKey) ToHex() string {
return hex.EncodeToString(key[:])
}
func (key *NoisePublicKey) FromHex(src string) error {
return loadExactHex(key[:], src)
}
func (key NoisePublicKey) ToHex() string {
return hex.EncodeToString(key[:])
}
func (key NoisePublicKey) IsZero() bool {
var zero NoisePublicKey
return key.Equals(zero)
}
func (key NoisePublicKey) Equals(tar NoisePublicKey) bool {
return subtle.ConstantTimeCompare(key[:], tar[:]) == 1
}
func (key *NoiseSymmetricKey) FromHex(src string) error {
return loadExactHex(key[:], src)
}
func (key NoiseSymmetricKey) ToHex() string {
return hex.EncodeToString(key[:])
}

144
device/noise_test.go Normal file
View File

@@ -0,0 +1,144 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bytes"
"encoding/binary"
"testing"
)
func TestCurveWrappers(t *testing.T) {
sk1, err := newPrivateKey()
assertNil(t, err)
sk2, err := newPrivateKey()
assertNil(t, err)
pk1 := sk1.publicKey()
pk2 := sk2.publicKey()
ss1 := sk1.sharedSecret(pk2)
ss2 := sk2.sharedSecret(pk1)
if ss1 != ss2 {
t.Fatal("Failed to compute shared secet")
}
}
func TestNoiseHandshake(t *testing.T) {
dev1 := randDevice(t)
dev2 := randDevice(t)
defer dev1.Close()
defer dev2.Close()
peer1, _ := dev2.NewPeer(dev1.staticIdentity.privateKey.publicKey())
peer2, _ := dev1.NewPeer(dev2.staticIdentity.privateKey.publicKey())
assertEqual(
t,
peer1.handshake.precomputedStaticStatic[:],
peer2.handshake.precomputedStaticStatic[:],
)
/* simulate handshake */
// initiation message
t.Log("exchange initiation message")
msg1, err := dev1.CreateMessageInitiation(peer2)
assertNil(t, err)
packet := make([]byte, 0, 256)
writer := bytes.NewBuffer(packet)
err = binary.Write(writer, binary.LittleEndian, msg1)
assertNil(t, err)
peer := dev2.ConsumeMessageInitiation(msg1)
if peer == nil {
t.Fatal("handshake failed at initiation message")
}
assertEqual(
t,
peer1.handshake.chainKey[:],
peer2.handshake.chainKey[:],
)
assertEqual(
t,
peer1.handshake.hash[:],
peer2.handshake.hash[:],
)
// response message
t.Log("exchange response message")
msg2, err := dev2.CreateMessageResponse(peer1)
assertNil(t, err)
peer = dev1.ConsumeMessageResponse(msg2)
if peer == nil {
t.Fatal("handshake failed at response message")
}
assertEqual(
t,
peer1.handshake.chainKey[:],
peer2.handshake.chainKey[:],
)
assertEqual(
t,
peer1.handshake.hash[:],
peer2.handshake.hash[:],
)
// key pairs
t.Log("deriving keys")
err = peer1.BeginSymmetricSession()
if err != nil {
t.Fatal("failed to derive keypair for peer 1", err)
}
err = peer2.BeginSymmetricSession()
if err != nil {
t.Fatal("failed to derive keypair for peer 2", err)
}
key1 := peer1.keypairs.next
key2 := peer2.keypairs.current
// encrypting / decryption test
t.Log("test key pairs")
func() {
testMsg := []byte("wireguard test message 1")
var err error
var out []byte
var nonce [12]byte
out = key1.send.Seal(out, nonce[:], testMsg, nil)
out, err = key2.receive.Open(out[:0], nonce[:], out, nil)
assertNil(t, err)
assertEqual(t, out, testMsg)
}()
func() {
testMsg := []byte("wireguard test message 2")
var err error
var out []byte
var nonce [12]byte
out = key2.send.Seal(out, nonce[:], testMsg, nil)
out, err = key1.receive.Open(out[:0], nonce[:], out, nil)
assertNil(t, err)
assertEqual(t, out, testMsg)
}()
}

270
device/peer.go Normal file
View File

@@ -0,0 +1,270 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"encoding/base64"
"errors"
"fmt"
"sync"
"time"
)
const (
PeerRoutineNumber = 3
)
type Peer struct {
isRunning AtomicBool
sync.RWMutex // Mostly protects endpoint, but is generally taken whenever we modify peer
keypairs Keypairs
handshake Handshake
device *Device
endpoint Endpoint
persistentKeepaliveInterval uint16
// This must be 64-bit aligned, so make sure the above members come out to even alignment and pad accordingly
stats struct {
txBytes uint64 // bytes send to peer (endpoint)
rxBytes uint64 // bytes received from peer
lastHandshakeNano int64 // nano seconds since epoch
}
timers struct {
retransmitHandshake *Timer
sendKeepalive *Timer
newHandshake *Timer
zeroKeyMaterial *Timer
persistentKeepalive *Timer
handshakeAttempts uint32
needAnotherKeepalive AtomicBool
sentLastMinuteHandshake AtomicBool
}
signals struct {
newKeypairArrived chan struct{}
flushNonceQueue chan struct{}
}
queue struct {
nonce chan *QueueOutboundElement // nonce / pre-handshake queue
outbound chan *QueueOutboundElement // sequential ordering of work
inbound chan *QueueInboundElement // sequential ordering of work
packetInNonceQueueIsAwaitingKey AtomicBool
}
routines struct {
sync.Mutex // held when stopping / starting routines
starting sync.WaitGroup // routines pending start
stopping sync.WaitGroup // routines pending stop
stop chan struct{} // size 0, stop all go routines in peer
}
cookieGenerator CookieGenerator
}
func (device *Device) NewPeer(pk NoisePublicKey) (*Peer, error) {
if device.isClosed.Get() {
return nil, errors.New("device closed")
}
// lock resources
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
device.peers.Lock()
defer device.peers.Unlock()
// check if over limit
if len(device.peers.keyMap) >= MaxPeers {
return nil, errors.New("too many peers")
}
// create peer
peer := new(Peer)
peer.Lock()
defer peer.Unlock()
peer.cookieGenerator.Init(pk)
peer.device = device
peer.isRunning.Set(false)
// map public key
_, ok := device.peers.keyMap[pk]
if ok {
return nil, errors.New("adding existing peer")
}
device.peers.keyMap[pk] = peer
// pre-compute DH
handshake := &peer.handshake
handshake.mutex.Lock()
handshake.remoteStatic = pk
handshake.precomputedStaticStatic = device.staticIdentity.privateKey.sharedSecret(pk)
handshake.mutex.Unlock()
// reset endpoint
peer.endpoint = nil
// start peer
if peer.device.isUp.Get() {
peer.Start()
}
return peer, nil
}
func (peer *Peer) SendBuffer(buffer []byte) error {
peer.device.net.RLock()
defer peer.device.net.RUnlock()
if peer.device.net.bind == nil {
return errors.New("no bind")
}
peer.RLock()
defer peer.RUnlock()
if peer.endpoint == nil {
return errors.New("no known endpoint for peer")
}
return peer.device.net.bind.Send(buffer, peer.endpoint)
}
func (peer *Peer) String() string {
base64Key := base64.StdEncoding.EncodeToString(peer.handshake.remoteStatic[:])
abbreviatedKey := "invalid"
if len(base64Key) == 44 {
abbreviatedKey = base64Key[0:4] + "…" + base64Key[39:43]
}
return fmt.Sprintf("peer(%s)", abbreviatedKey)
}
func (peer *Peer) Start() {
// should never start a peer on a closed device
if peer.device.isClosed.Get() {
return
}
// prevent simultaneous start/stop operations
peer.routines.Lock()
defer peer.routines.Unlock()
if peer.isRunning.Get() {
return
}
device := peer.device
device.log.Debug.Println(peer, "- Starting...")
// reset routine state
peer.routines.starting.Wait()
peer.routines.stopping.Wait()
peer.routines.stop = make(chan struct{})
peer.routines.starting.Add(PeerRoutineNumber)
peer.routines.stopping.Add(PeerRoutineNumber)
// prepare queues
peer.queue.nonce = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.outbound = make(chan *QueueOutboundElement, QueueOutboundSize)
peer.queue.inbound = make(chan *QueueInboundElement, QueueInboundSize)
peer.timersInit()
peer.handshake.lastSentHandshake = time.Now().Add(-(RekeyTimeout + time.Second))
peer.signals.newKeypairArrived = make(chan struct{}, 1)
peer.signals.flushNonceQueue = make(chan struct{}, 1)
// wait for routines to start
go peer.RoutineNonce()
go peer.RoutineSequentialSender()
go peer.RoutineSequentialReceiver()
peer.routines.starting.Wait()
peer.isRunning.Set(true)
}
func (peer *Peer) ZeroAndFlushAll() {
device := peer.device
// clear key pairs
keypairs := &peer.keypairs
keypairs.Lock()
device.DeleteKeypair(keypairs.previous)
device.DeleteKeypair(keypairs.current)
device.DeleteKeypair(keypairs.next)
keypairs.previous = nil
keypairs.current = nil
keypairs.next = nil
keypairs.Unlock()
// clear handshake state
handshake := &peer.handshake
handshake.mutex.Lock()
device.indexTable.Delete(handshake.localIndex)
handshake.Clear()
handshake.mutex.Unlock()
peer.FlushNonceQueue()
}
func (peer *Peer) Stop() {
// prevent simultaneous start/stop operations
if !peer.isRunning.Swap(false) {
return
}
peer.routines.starting.Wait()
peer.routines.Lock()
defer peer.routines.Unlock()
peer.device.log.Debug.Println(peer, "- Stopping...")
peer.timersStop()
// stop & wait for ongoing peer routines
close(peer.routines.stop)
peer.routines.stopping.Wait()
// close queues
close(peer.queue.nonce)
close(peer.queue.outbound)
close(peer.queue.inbound)
peer.ZeroAndFlushAll()
}
var roamingDisabled bool
func (peer *Peer) SetEndpointFromPacket(endpoint Endpoint) {
if roamingDisabled {
return
}
peer.Lock()
peer.endpoint = endpoint
peer.Unlock()
}

89
device/pools.go Normal file
View File

@@ -0,0 +1,89 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import "sync"
func (device *Device) PopulatePools() {
if PreallocatedBuffersPerPool == 0 {
device.pool.messageBufferPool = &sync.Pool{
New: func() interface{} {
return new([MaxMessageSize]byte)
},
}
device.pool.inboundElementPool = &sync.Pool{
New: func() interface{} {
return new(QueueInboundElement)
},
}
device.pool.outboundElementPool = &sync.Pool{
New: func() interface{} {
return new(QueueOutboundElement)
},
}
} else {
device.pool.messageBufferReuseChan = make(chan *[MaxMessageSize]byte, PreallocatedBuffersPerPool)
for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
device.pool.messageBufferReuseChan <- new([MaxMessageSize]byte)
}
device.pool.inboundElementReuseChan = make(chan *QueueInboundElement, PreallocatedBuffersPerPool)
for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
device.pool.inboundElementReuseChan <- new(QueueInboundElement)
}
device.pool.outboundElementReuseChan = make(chan *QueueOutboundElement, PreallocatedBuffersPerPool)
for i := 0; i < PreallocatedBuffersPerPool; i += 1 {
device.pool.outboundElementReuseChan <- new(QueueOutboundElement)
}
}
}
func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
if PreallocatedBuffersPerPool == 0 {
return device.pool.messageBufferPool.Get().(*[MaxMessageSize]byte)
} else {
return <-device.pool.messageBufferReuseChan
}
}
func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
if PreallocatedBuffersPerPool == 0 {
device.pool.messageBufferPool.Put(msg)
} else {
device.pool.messageBufferReuseChan <- msg
}
}
func (device *Device) GetInboundElement() *QueueInboundElement {
if PreallocatedBuffersPerPool == 0 {
return device.pool.inboundElementPool.Get().(*QueueInboundElement)
} else {
return <-device.pool.inboundElementReuseChan
}
}
func (device *Device) PutInboundElement(msg *QueueInboundElement) {
if PreallocatedBuffersPerPool == 0 {
device.pool.inboundElementPool.Put(msg)
} else {
device.pool.inboundElementReuseChan <- msg
}
}
func (device *Device) GetOutboundElement() *QueueOutboundElement {
if PreallocatedBuffersPerPool == 0 {
return device.pool.outboundElementPool.Get().(*QueueOutboundElement)
} else {
return <-device.pool.outboundElementReuseChan
}
}
func (device *Device) PutOutboundElement(msg *QueueOutboundElement) {
if PreallocatedBuffersPerPool == 0 {
device.pool.outboundElementPool.Put(msg)
} else {
device.pool.outboundElementReuseChan <- msg
}
}

16
device/queueconstants.go Normal file
View File

@@ -0,0 +1,16 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
/* Implementation specific constants */
const (
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
)

641
device/receive.go Normal file
View File

@@ -0,0 +1,641 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bytes"
"encoding/binary"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net"
"strconv"
"sync"
"sync/atomic"
"time"
)
type QueueHandshakeElement struct {
msgType uint32
packet []byte
endpoint Endpoint
buffer *[MaxMessageSize]byte
}
type QueueInboundElement struct {
dropped int32
sync.Mutex
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
keypair *Keypair
endpoint Endpoint
}
func (elem *QueueInboundElement) Drop() {
atomic.StoreInt32(&elem.dropped, AtomicTrue)
}
func (elem *QueueInboundElement) IsDropped() bool {
return atomic.LoadInt32(&elem.dropped) == AtomicTrue
}
func (device *Device) addToInboundAndDecryptionQueues(inboundQueue chan *QueueInboundElement, decryptionQueue chan *QueueInboundElement, element *QueueInboundElement) bool {
select {
case inboundQueue <- element:
select {
case decryptionQueue <- element:
return true
default:
element.Drop()
element.Unlock()
return false
}
default:
device.PutInboundElement(element)
return false
}
}
func (device *Device) addToHandshakeQueue(queue chan QueueHandshakeElement, element QueueHandshakeElement) bool {
select {
case queue <- element:
return true
default:
return false
}
}
/* Called when a new authenticated message has been received
*
* NOTE: Not thread safe, but called by sequential receiver!
*/
func (peer *Peer) keepKeyFreshReceiving() {
if peer.timers.sentLastMinuteHandshake.Get() {
return
}
keypair := peer.keypairs.Current()
if keypair != nil && keypair.isInitiator && time.Now().Sub(keypair.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
peer.timers.sentLastMinuteHandshake.Set(true)
peer.SendHandshakeInitiation(false)
}
}
/* Receives incoming datagrams for the device
*
* Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately)
*/
func (device *Device) RoutineReceiveIncoming(IP int, bind Bind) {
logDebug := device.log.Debug
defer func() {
logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - stopped")
device.net.stopping.Done()
}()
logDebug.Println("Routine: receive incoming IPv" + strconv.Itoa(IP) + " - started")
device.net.starting.Done()
// receive datagrams until conn is closed
buffer := device.GetMessageBuffer()
var (
err error
size int
endpoint Endpoint
)
for {
// read next datagram
switch IP {
case ipv4.Version:
size, endpoint, err = bind.ReceiveIPv4(buffer[:])
case ipv6.Version:
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
default:
panic("invalid IP version")
}
if err != nil {
device.PutMessageBuffer(buffer)
return
}
if size < MinMessageSize {
continue
}
// check size of packet
packet := buffer[:size]
msgType := binary.LittleEndian.Uint32(packet[:4])
var okay bool
switch msgType {
// check if transport
case MessageTransportType:
// check size
if len(packet) < MessageTransportSize {
continue
}
// lookup key pair
receiver := binary.LittleEndian.Uint32(
packet[MessageTransportOffsetReceiver:MessageTransportOffsetCounter],
)
value := device.indexTable.Lookup(receiver)
keypair := value.keypair
if keypair == nil {
continue
}
// check keypair expiry
if keypair.created.Add(RejectAfterTime).Before(time.Now()) {
continue
}
// create work element
peer := value.peer
elem := device.GetInboundElement()
elem.packet = packet
elem.buffer = buffer
elem.keypair = keypair
elem.dropped = AtomicFalse
elem.endpoint = endpoint
elem.counter = 0
elem.Mutex = sync.Mutex{}
elem.Lock()
// add to decryption queues
if peer.isRunning.Get() {
if device.addToInboundAndDecryptionQueues(peer.queue.inbound, device.queue.decryption, elem) {
buffer = device.GetMessageBuffer()
}
}
continue
// otherwise it is a fixed size & handshake related packet
case MessageInitiationType:
okay = len(packet) == MessageInitiationSize
case MessageResponseType:
okay = len(packet) == MessageResponseSize
case MessageCookieReplyType:
okay = len(packet) == MessageCookieReplySize
default:
logDebug.Println("Received message with unknown type")
}
if okay {
if (device.addToHandshakeQueue(
device.queue.handshake,
QueueHandshakeElement{
msgType: msgType,
buffer: buffer,
packet: packet,
endpoint: endpoint,
},
)) {
buffer = device.GetMessageBuffer()
}
}
}
}
func (device *Device) RoutineDecryption() {
var nonce [chacha20poly1305.NonceSize]byte
logDebug := device.log.Debug
defer func() {
logDebug.Println("Routine: decryption worker - stopped")
device.state.stopping.Done()
}()
logDebug.Println("Routine: decryption worker - started")
device.state.starting.Done()
for {
select {
case <-device.signals.stop:
return
case elem, ok := <-device.queue.decryption:
if !ok {
return
}
// check if dropped
if elem.IsDropped() {
continue
}
// split message into fields
counter := elem.packet[MessageTransportOffsetCounter:MessageTransportOffsetContent]
content := elem.packet[MessageTransportOffsetContent:]
// expand nonce
nonce[0x4] = counter[0x0]
nonce[0x5] = counter[0x1]
nonce[0x6] = counter[0x2]
nonce[0x7] = counter[0x3]
nonce[0x8] = counter[0x4]
nonce[0x9] = counter[0x5]
nonce[0xa] = counter[0x6]
nonce[0xb] = counter[0x7]
// decrypt and release to consumer
var err error
elem.counter = binary.LittleEndian.Uint64(counter)
elem.packet, err = elem.keypair.receive.Open(
content[:0],
nonce[:],
content,
nil,
)
if err != nil {
elem.Drop()
device.PutMessageBuffer(elem.buffer)
}
elem.Unlock()
}
}
}
/* Handles incoming packets related to handshake
*/
func (device *Device) RoutineHandshake() {
logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
var elem QueueHandshakeElement
var ok bool
defer func() {
logDebug.Println("Routine: handshake worker - stopped")
device.state.stopping.Done()
if elem.buffer != nil {
device.PutMessageBuffer(elem.buffer)
}
}()
logDebug.Println("Routine: handshake worker - started")
device.state.starting.Done()
for {
if elem.buffer != nil {
device.PutMessageBuffer(elem.buffer)
elem.buffer = nil
}
select {
case elem, ok = <-device.queue.handshake:
case <-device.signals.stop:
return
}
if !ok {
return
}
// handle cookie fields and ratelimiting
switch elem.msgType {
case MessageCookieReplyType:
// unmarshal packet
var reply MessageCookieReply
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &reply)
if err != nil {
logDebug.Println("Failed to decode cookie reply")
return
}
// lookup peer from index
entry := device.indexTable.Lookup(reply.Receiver)
if entry.peer == nil {
continue
}
// consume reply
if peer := entry.peer; peer.isRunning.Get() {
logDebug.Println("Receiving cookie response from ", elem.endpoint.DstToString())
if !peer.cookieGenerator.ConsumeReply(&reply) {
logDebug.Println("Could not decrypt invalid cookie response")
}
}
continue
case MessageInitiationType, MessageResponseType:
// check mac fields and maybe ratelimit
if !device.cookieChecker.CheckMAC1(elem.packet) {
logDebug.Println("Received packet with invalid mac1")
continue
}
// endpoints destination address is the source of the datagram
if device.IsUnderLoad() {
// verify MAC2 field
if !device.cookieChecker.CheckMAC2(elem.packet, elem.endpoint.DstToBytes()) {
device.SendHandshakeCookie(&elem)
continue
}
// check ratelimiter
if !device.rate.limiter.Allow(elem.endpoint.DstIP()) {
continue
}
}
default:
logError.Println("Invalid packet ended up in the handshake queue")
continue
}
// handle handshake initiation/response content
switch elem.msgType {
case MessageInitiationType:
// unmarshal
var msg MessageInitiation
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
logError.Println("Failed to decode initiation message")
continue
}
// consume initiation
peer := device.ConsumeMessageInitiation(&msg)
if peer == nil {
logInfo.Println(
"Received invalid initiation message from",
elem.endpoint.DstToString(),
)
continue
}
// update timers
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
logDebug.Println(peer, "- Received handshake initiation")
peer.SendHandshakeResponse()
case MessageResponseType:
// unmarshal
var msg MessageResponse
reader := bytes.NewReader(elem.packet)
err := binary.Read(reader, binary.LittleEndian, &msg)
if err != nil {
logError.Println("Failed to decode response message")
continue
}
// consume response
peer := device.ConsumeMessageResponse(&msg)
if peer == nil {
logInfo.Println(
"Received invalid response message from",
elem.endpoint.DstToString(),
)
continue
}
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
logDebug.Println(peer, "- Received handshake response")
// update timers
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
// derive keypair
err = peer.BeginSymmetricSession()
if err != nil {
logError.Println(peer, "- Failed to derive keypair:", err)
continue
}
peer.timersSessionDerived()
peer.timersHandshakeComplete()
peer.SendKeepalive()
select {
case peer.signals.newKeypairArrived <- struct{}{}:
default:
}
}
}
}
func (peer *Peer) RoutineSequentialReceiver() {
device := peer.device
logInfo := device.log.Info
logError := device.log.Error
logDebug := device.log.Debug
var elem *QueueInboundElement
var ok bool
defer func() {
logDebug.Println(peer, "- Routine: sequential receiver - stopped")
peer.routines.stopping.Done()
if elem != nil {
if !elem.IsDropped() {
device.PutMessageBuffer(elem.buffer)
}
device.PutInboundElement(elem)
}
}()
logDebug.Println(peer, "- Routine: sequential receiver - started")
peer.routines.starting.Done()
for {
if elem != nil {
if !elem.IsDropped() {
device.PutMessageBuffer(elem.buffer)
}
device.PutInboundElement(elem)
elem = nil
}
select {
case <-peer.routines.stop:
return
case elem, ok = <-peer.queue.inbound:
if !ok {
return
}
// wait for decryption
elem.Lock()
if elem.IsDropped() {
continue
}
// check for replay
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
continue
}
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
// check if using new keypair
if peer.ReceivedWithKeypair(elem.keypair) {
peer.timersHandshakeComplete()
select {
case peer.signals.newKeypairArrived <- struct{}{}:
default:
}
}
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
// check for keepalive
if len(elem.packet) == 0 {
logDebug.Println(peer, "- Receiving keepalive packet")
continue
}
peer.timersDataReceived()
// verify source and strip padding
switch elem.packet[0] >> 4 {
case ipv4.Version:
// strip padding
if len(elem.packet) < ipv4.HeaderLen {
continue
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
continue
}
elem.packet = elem.packet[:length]
// verify IPv4 source
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.LookupIPv4(src) != peer {
logInfo.Println(
"IPv4 packet with disallowed source address from",
peer,
)
continue
}
case ipv6.Version:
// strip padding
if len(elem.packet) < ipv6.HeaderLen {
continue
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
continue
}
elem.packet = elem.packet[:length]
// verify IPv6 source
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.LookupIPv6(src) != peer {
logInfo.Println(
peer,
"sent packet with disallowed IPv6 source",
)
continue
}
default:
logInfo.Println("Packet with invalid IP version from", peer)
continue
}
// write to tun device
offset := MessageTransportOffsetContent
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)))
_, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
if err != nil {
logError.Println("Failed to write packet to TUN device:", err)
}
}
}
}

618
device/send.go Normal file
View File

@@ -0,0 +1,618 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bytes"
"encoding/binary"
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"net"
"sync"
"sync/atomic"
"time"
)
/* Outbound flow
*
* 1. TUN queue
* 2. Routing (sequential)
* 3. Nonce assignment (sequential)
* 4. Encryption (parallel)
* 5. Transmission (sequential)
*
* The functions in this file occur (roughly) in the order in
* which the packets are processed.
*
* Locking, Producers and Consumers
*
* The order of packets (per peer) must be maintained,
* but encryption of packets happen out-of-order:
*
* The sequential consumers will attempt to take the lock,
* workers release lock when they have completed work (encryption) on the packet.
*
* If the element is inserted into the "encryption queue",
* the content is preceded by enough "junk" to contain the transport header
* (to allow the construction of transport messages in-place)
*/
type QueueOutboundElement struct {
dropped int32
sync.Mutex
buffer *[MaxMessageSize]byte // slice holding the packet data
packet []byte // slice of "buffer" (always!)
nonce uint64 // nonce for encryption
keypair *Keypair // keypair for encryption
peer *Peer // related peer
}
func (device *Device) NewOutboundElement() *QueueOutboundElement {
elem := device.GetOutboundElement()
elem.dropped = AtomicFalse
elem.buffer = device.GetMessageBuffer()
elem.Mutex = sync.Mutex{}
elem.nonce = 0
elem.keypair = nil
elem.peer = nil
return elem
}
func (elem *QueueOutboundElement) Drop() {
atomic.StoreInt32(&elem.dropped, AtomicTrue)
}
func (elem *QueueOutboundElement) IsDropped() bool {
return atomic.LoadInt32(&elem.dropped) == AtomicTrue
}
func addToNonceQueue(queue chan *QueueOutboundElement, element *QueueOutboundElement, device *Device) {
for {
select {
case queue <- element:
return
default:
select {
case old := <-queue:
device.PutMessageBuffer(old.buffer)
device.PutOutboundElement(old)
default:
}
}
}
}
func addToOutboundAndEncryptionQueues(outboundQueue chan *QueueOutboundElement, encryptionQueue chan *QueueOutboundElement, element *QueueOutboundElement) {
select {
case outboundQueue <- element:
select {
case encryptionQueue <- element:
return
default:
element.Drop()
element.peer.device.PutMessageBuffer(element.buffer)
element.Unlock()
}
default:
element.peer.device.PutMessageBuffer(element.buffer)
element.peer.device.PutOutboundElement(element)
}
}
/* Queues a keepalive if no packets are queued for peer
*/
func (peer *Peer) SendKeepalive() bool {
if len(peer.queue.nonce) != 0 || peer.queue.packetInNonceQueueIsAwaitingKey.Get() || !peer.isRunning.Get() {
return false
}
elem := peer.device.NewOutboundElement()
elem.packet = nil
select {
case peer.queue.nonce <- elem:
peer.device.log.Debug.Println(peer, "- Sending keepalive packet")
return true
default:
peer.device.PutMessageBuffer(elem.buffer)
peer.device.PutOutboundElement(elem)
return false
}
}
func (peer *Peer) SendHandshakeInitiation(isRetry bool) error {
if !isRetry {
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
}
peer.handshake.mutex.RLock()
if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout {
peer.handshake.mutex.RUnlock()
return nil
}
peer.handshake.mutex.RUnlock()
peer.handshake.mutex.Lock()
if time.Now().Sub(peer.handshake.lastSentHandshake) < RekeyTimeout {
peer.handshake.mutex.Unlock()
return nil
}
peer.handshake.lastSentHandshake = time.Now()
peer.handshake.mutex.Unlock()
peer.device.log.Debug.Println(peer, "- Sending handshake initiation")
msg, err := peer.device.CreateMessageInitiation(peer)
if err != nil {
peer.device.log.Error.Println(peer, "- Failed to create initiation message:", err)
return err
}
var buff [MessageInitiationSize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, msg)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
err = peer.SendBuffer(packet)
if err != nil {
peer.device.log.Error.Println(peer, "- Failed to send handshake initiation", err)
}
peer.timersHandshakeInitiated()
return err
}
func (peer *Peer) SendHandshakeResponse() error {
peer.handshake.mutex.Lock()
peer.handshake.lastSentHandshake = time.Now()
peer.handshake.mutex.Unlock()
peer.device.log.Debug.Println(peer, "- Sending handshake response")
response, err := peer.device.CreateMessageResponse(peer)
if err != nil {
peer.device.log.Error.Println(peer, "- Failed to create response message:", err)
return err
}
var buff [MessageResponseSize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, response)
packet := writer.Bytes()
peer.cookieGenerator.AddMacs(packet)
err = peer.BeginSymmetricSession()
if err != nil {
peer.device.log.Error.Println(peer, "- Failed to derive keypair:", err)
return err
}
peer.timersSessionDerived()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
err = peer.SendBuffer(packet)
if err != nil {
peer.device.log.Error.Println(peer, "- Failed to send handshake response", err)
}
return err
}
func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) error {
device.log.Debug.Println("Sending cookie response for denied handshake message for", initiatingElem.endpoint.DstToString())
sender := binary.LittleEndian.Uint32(initiatingElem.packet[4:8])
reply, err := device.cookieChecker.CreateReply(initiatingElem.packet, sender, initiatingElem.endpoint.DstToBytes())
if err != nil {
device.log.Error.Println("Failed to create cookie reply:", err)
return err
}
var buff [MessageCookieReplySize]byte
writer := bytes.NewBuffer(buff[:0])
binary.Write(writer, binary.LittleEndian, reply)
device.net.bind.Send(writer.Bytes(), initiatingElem.endpoint)
if err != nil {
device.log.Error.Println("Failed to send cookie reply:", err)
}
return err
}
func (peer *Peer) keepKeyFreshSending() {
keypair := peer.keypairs.Current()
if keypair == nil {
return
}
nonce := atomic.LoadUint64(&keypair.sendNonce)
if nonce > RekeyAfterMessages || (keypair.isInitiator && time.Now().Sub(keypair.created) > RekeyAfterTime) {
peer.SendHandshakeInitiation(false)
}
}
/* Reads packets from the TUN and inserts
* into nonce queue for peer
*
* Obs. Single instance per TUN device
*/
func (device *Device) RoutineReadFromTUN() {
logDebug := device.log.Debug
logError := device.log.Error
defer func() {
logDebug.Println("Routine: TUN reader - stopped")
device.state.stopping.Done()
}()
logDebug.Println("Routine: TUN reader - started")
device.state.starting.Done()
var elem *QueueOutboundElement
for {
if elem != nil {
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
}
elem = device.NewOutboundElement()
// read packet
offset := MessageTransportHeaderSize
size, err := device.tun.device.Read(elem.buffer[:], offset)
if err != nil {
if !device.isClosed.Get() {
logError.Println("Failed to read packet from TUN device:", err)
device.Close()
}
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
return
}
if size == 0 || size > MaxContentSize {
continue
}
elem.packet = elem.buffer[offset : offset+size]
// lookup peer
var peer *Peer
switch elem.packet[0] >> 4 {
case ipv4.Version:
if len(elem.packet) < ipv4.HeaderLen {
continue
}
dst := elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len]
peer = device.allowedips.LookupIPv4(dst)
case ipv6.Version:
if len(elem.packet) < ipv6.HeaderLen {
continue
}
dst := elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len]
peer = device.allowedips.LookupIPv6(dst)
default:
logDebug.Println("Received packet with unknown IP version")
}
if peer == nil {
continue
}
// insert into nonce/pre-handshake queue
if peer.isRunning.Get() {
if peer.queue.packetInNonceQueueIsAwaitingKey.Get() {
peer.SendHandshakeInitiation(false)
}
addToNonceQueue(peer.queue.nonce, elem, device)
elem = nil
}
}
}
func (peer *Peer) FlushNonceQueue() {
select {
case peer.signals.flushNonceQueue <- struct{}{}:
default:
}
}
/* Queues packets when there is no handshake.
* Then assigns nonces to packets sequentially
* and creates "work" structs for workers
*
* Obs. A single instance per peer
*/
func (peer *Peer) RoutineNonce() {
var keypair *Keypair
device := peer.device
logDebug := device.log.Debug
flush := func() {
for {
select {
case elem := <-peer.queue.nonce:
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
default:
return
}
}
}
defer func() {
flush()
logDebug.Println(peer, "- Routine: nonce worker - stopped")
peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
peer.routines.stopping.Done()
}()
peer.routines.starting.Done()
logDebug.Println(peer, "- Routine: nonce worker - started")
for {
NextPacket:
peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
select {
case <-peer.routines.stop:
return
case <-peer.signals.flushNonceQueue:
flush()
goto NextPacket
case elem, ok := <-peer.queue.nonce:
if !ok {
return
}
// make sure to always pick the newest key
for {
// check validity of newest key pair
keypair = peer.keypairs.Current()
if keypair != nil && keypair.sendNonce < RejectAfterMessages {
if time.Now().Sub(keypair.created) < RejectAfterTime {
break
}
}
peer.queue.packetInNonceQueueIsAwaitingKey.Set(true)
// no suitable key pair, request for new handshake
select {
case <-peer.signals.newKeypairArrived:
default:
}
peer.SendHandshakeInitiation(false)
// wait for key to be established
logDebug.Println(peer, "- Awaiting keypair")
select {
case <-peer.signals.newKeypairArrived:
logDebug.Println(peer, "- Obtained awaited keypair")
case <-peer.signals.flushNonceQueue:
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
flush()
goto NextPacket
case <-peer.routines.stop:
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
return
}
}
peer.queue.packetInNonceQueueIsAwaitingKey.Set(false)
// populate work element
elem.peer = peer
elem.nonce = atomic.AddUint64(&keypair.sendNonce, 1) - 1
// double check in case of race condition added by future code
if elem.nonce >= RejectAfterMessages {
atomic.StoreUint64(&keypair.sendNonce, RejectAfterMessages)
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
goto NextPacket
}
elem.keypair = keypair
elem.dropped = AtomicFalse
elem.Lock()
// add to parallel and sequential queue
addToOutboundAndEncryptionQueues(peer.queue.outbound, device.queue.encryption, elem)
}
}
}
/* Encrypts the elements in the queue
* and marks them for sequential consumption (by releasing the mutex)
*
* Obs. One instance per core
*/
func (device *Device) RoutineEncryption() {
var nonce [chacha20poly1305.NonceSize]byte
logDebug := device.log.Debug
defer func() {
for {
select {
case elem, ok := <-device.queue.encryption:
if ok && !elem.IsDropped() {
elem.Drop()
device.PutMessageBuffer(elem.buffer)
elem.Unlock()
}
default:
goto out
}
}
out:
logDebug.Println("Routine: encryption worker - stopped")
device.state.stopping.Done()
}()
logDebug.Println("Routine: encryption worker - started")
device.state.starting.Done()
for {
// fetch next element
select {
case <-device.signals.stop:
return
case elem, ok := <-device.queue.encryption:
if !ok {
return
}
// check if dropped
if elem.IsDropped() {
continue
}
// populate header fields
header := elem.buffer[:MessageTransportHeaderSize]
fieldType := header[0:4]
fieldReceiver := header[4:8]
fieldNonce := header[8:16]
binary.LittleEndian.PutUint32(fieldType, MessageTransportType)
binary.LittleEndian.PutUint32(fieldReceiver, elem.keypair.remoteIndex)
binary.LittleEndian.PutUint64(fieldNonce, elem.nonce)
// pad content to multiple of 16
mtu := int(atomic.LoadInt32(&device.tun.mtu))
lastUnit := len(elem.packet) % mtu
paddedSize := (lastUnit + PaddingMultiple - 1) & ^(PaddingMultiple - 1)
if paddedSize > mtu {
paddedSize = mtu
}
for i := len(elem.packet); i < paddedSize; i++ {
elem.packet = append(elem.packet, 0)
}
// encrypt content and release to consumer
binary.LittleEndian.PutUint64(nonce[4:], elem.nonce)
elem.packet = elem.keypair.send.Seal(
header,
nonce[:],
elem.packet,
nil,
)
elem.Unlock()
}
}
}
/* Sequentially reads packets from queue and sends to endpoint
*
* Obs. Single instance per peer.
* The routine terminates then the outbound queue is closed.
*/
func (peer *Peer) RoutineSequentialSender() {
device := peer.device
logDebug := device.log.Debug
logError := device.log.Error
defer func() {
for {
select {
case elem, ok := <-peer.queue.outbound:
if ok {
if !elem.IsDropped() {
device.PutMessageBuffer(elem.buffer)
elem.Drop()
}
device.PutOutboundElement(elem)
}
default:
goto out
}
}
out:
logDebug.Println(peer, "- Routine: sequential sender - stopped")
peer.routines.stopping.Done()
}()
logDebug.Println(peer, "- Routine: sequential sender - started")
peer.routines.starting.Done()
for {
select {
case <-peer.routines.stop:
return
case elem, ok := <-peer.queue.outbound:
if !ok {
return
}
elem.Lock()
if elem.IsDropped() {
device.PutOutboundElement(elem)
continue
}
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketSent()
// send message and return buffer to pool
length := uint64(len(elem.packet))
err := peer.SendBuffer(elem.packet)
device.PutMessageBuffer(elem.buffer)
device.PutOutboundElement(elem)
if err != nil {
logError.Println(peer, "- Failed to send data packet", err)
continue
}
atomic.AddUint64(&peer.stats.txBytes, length)
if len(elem.packet) != MessageKeepaliveSize {
peer.timersDataSent()
}
peer.keepKeyFreshSending()
}
}
}

227
device/timers.go Normal file
View File

@@ -0,0 +1,227 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*
* This is based heavily on timers.c from the kernel implementation.
*/
package device
import (
"math/rand"
"sync"
"sync/atomic"
"time"
)
/* This Timer structure and related functions should roughly copy the interface of
* the Linux kernel's struct timer_list.
*/
type Timer struct {
*time.Timer
modifyingLock sync.RWMutex
runningLock sync.Mutex
isPending bool
}
func (peer *Peer) NewTimer(expirationFunction func(*Peer)) *Timer {
timer := &Timer{}
timer.Timer = time.AfterFunc(time.Hour, func() {
timer.runningLock.Lock()
timer.modifyingLock.Lock()
if !timer.isPending {
timer.modifyingLock.Unlock()
timer.runningLock.Unlock()
return
}
timer.isPending = false
timer.modifyingLock.Unlock()
expirationFunction(peer)
timer.runningLock.Unlock()
})
timer.Stop()
return timer
}
func (timer *Timer) Mod(d time.Duration) {
timer.modifyingLock.Lock()
timer.isPending = true
timer.Reset(d)
timer.modifyingLock.Unlock()
}
func (timer *Timer) Del() {
timer.modifyingLock.Lock()
timer.isPending = false
timer.Stop()
timer.modifyingLock.Unlock()
}
func (timer *Timer) DelSync() {
timer.Del()
timer.runningLock.Lock()
timer.Del()
timer.runningLock.Unlock()
}
func (timer *Timer) IsPending() bool {
timer.modifyingLock.RLock()
defer timer.modifyingLock.RUnlock()
return timer.isPending
}
func (peer *Peer) timersActive() bool {
return peer.isRunning.Get() && peer.device != nil && peer.device.isUp.Get() && len(peer.device.peers.keyMap) > 0
}
func expiredRetransmitHandshake(peer *Peer) {
if atomic.LoadUint32(&peer.timers.handshakeAttempts) > MaxTimerHandshakes {
peer.device.log.Debug.Printf("%s - Handshake did not complete after %d attempts, giving up\n", peer, MaxTimerHandshakes+2)
if peer.timersActive() {
peer.timers.sendKeepalive.Del()
}
/* We drop all packets without a keypair and don't try again,
* if we try unsuccessfully for too long to make a handshake.
*/
peer.FlushNonceQueue()
/* We set a timer for destroying any residue that might be left
* of a partial exchange.
*/
if peer.timersActive() && !peer.timers.zeroKeyMaterial.IsPending() {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
}
} else {
atomic.AddUint32(&peer.timers.handshakeAttempts, 1)
peer.device.log.Debug.Printf("%s - Handshake did not complete after %d seconds, retrying (try %d)\n", peer, int(RekeyTimeout.Seconds()), atomic.LoadUint32(&peer.timers.handshakeAttempts)+1)
/* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.Unlock()
peer.SendHandshakeInitiation(true)
}
}
func expiredSendKeepalive(peer *Peer) {
peer.SendKeepalive()
if peer.timers.needAnotherKeepalive.Get() {
peer.timers.needAnotherKeepalive.Set(false)
if peer.timersActive() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
}
}
}
func expiredNewHandshake(peer *Peer) {
peer.device.log.Debug.Printf("%s - Retrying handshake because we stopped hearing back after %d seconds\n", peer, int((KeepaliveTimeout + RekeyTimeout).Seconds()))
/* We clear the endpoint address src address, in case this is the cause of trouble. */
peer.Lock()
if peer.endpoint != nil {
peer.endpoint.ClearSrc()
}
peer.Unlock()
peer.SendHandshakeInitiation(false)
}
func expiredZeroKeyMaterial(peer *Peer) {
peer.device.log.Debug.Printf("%s - Removing all keys, since we haven't received a new one in %d seconds\n", peer, int((RejectAfterTime * 3).Seconds()))
peer.ZeroAndFlushAll()
}
func expiredPersistentKeepalive(peer *Peer) {
if peer.persistentKeepaliveInterval > 0 {
peer.SendKeepalive()
}
}
/* Should be called after an authenticated data packet is sent. */
func (peer *Peer) timersDataSent() {
if peer.timersActive() && !peer.timers.newHandshake.IsPending() {
peer.timers.newHandshake.Mod(KeepaliveTimeout + RekeyTimeout)
}
}
/* Should be called after an authenticated data packet is received. */
func (peer *Peer) timersDataReceived() {
if peer.timersActive() {
if !peer.timers.sendKeepalive.IsPending() {
peer.timers.sendKeepalive.Mod(KeepaliveTimeout)
} else {
peer.timers.needAnotherKeepalive.Set(true)
}
}
}
/* Should be called after any type of authenticated packet is sent -- keepalive, data, or handshake. */
func (peer *Peer) timersAnyAuthenticatedPacketSent() {
if peer.timersActive() {
peer.timers.sendKeepalive.Del()
}
}
/* Should be called after any type of authenticated packet is received -- keepalive, data, or handshake. */
func (peer *Peer) timersAnyAuthenticatedPacketReceived() {
if peer.timersActive() {
peer.timers.newHandshake.Del()
}
}
/* Should be called after a handshake initiation message is sent. */
func (peer *Peer) timersHandshakeInitiated() {
if peer.timersActive() {
peer.timers.retransmitHandshake.Mod(RekeyTimeout + time.Millisecond*time.Duration(rand.Int31n(RekeyTimeoutJitterMaxMs)))
}
}
/* Should be called after a handshake response message is received and processed or when getting key confirmation via the first data message. */
func (peer *Peer) timersHandshakeComplete() {
if peer.timersActive() {
peer.timers.retransmitHandshake.Del()
}
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
peer.timers.sentLastMinuteHandshake.Set(false)
atomic.StoreInt64(&peer.stats.lastHandshakeNano, time.Now().UnixNano())
}
/* Should be called after an ephemeral key is created, which is before sending a handshake response or after receiving a handshake response. */
func (peer *Peer) timersSessionDerived() {
if peer.timersActive() {
peer.timers.zeroKeyMaterial.Mod(RejectAfterTime * 3)
}
}
/* Should be called before a packet with authentication -- keepalive, data, or handshake -- is sent, or after one is received. */
func (peer *Peer) timersAnyAuthenticatedPacketTraversal() {
if peer.persistentKeepaliveInterval > 0 && peer.timersActive() {
peer.timers.persistentKeepalive.Mod(time.Duration(peer.persistentKeepaliveInterval) * time.Second)
}
}
func (peer *Peer) timersInit() {
peer.timers.retransmitHandshake = peer.NewTimer(expiredRetransmitHandshake)
peer.timers.sendKeepalive = peer.NewTimer(expiredSendKeepalive)
peer.timers.newHandshake = peer.NewTimer(expiredNewHandshake)
peer.timers.zeroKeyMaterial = peer.NewTimer(expiredZeroKeyMaterial)
peer.timers.persistentKeepalive = peer.NewTimer(expiredPersistentKeepalive)
atomic.StoreUint32(&peer.timers.handshakeAttempts, 0)
peer.timers.sentLastMinuteHandshake.Set(false)
peer.timers.needAnotherKeepalive.Set(false)
}
func (peer *Peer) timersStop() {
peer.timers.retransmitHandshake.DelSync()
peer.timers.sendKeepalive.DelSync()
peer.timers.newHandshake.DelSync()
peer.timers.zeroKeyMaterial.DelSync()
peer.timers.persistentKeepalive.DelSync()
}

55
device/tun.go Normal file
View File

@@ -0,0 +1,55 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"golang.zx2c4.com/wireguard/tun"
"sync/atomic"
)
const DefaultMTU = 1420
func (device *Device) RoutineTUNEventReader() {
setUp := false
logDebug := device.log.Debug
logInfo := device.log.Info
logError := device.log.Error
logDebug.Println("Routine: event worker - started")
device.state.starting.Done()
for event := range device.tun.device.Events() {
if event&tun.TUNEventMTUUpdate != 0 {
mtu, err := device.tun.device.MTU()
old := atomic.LoadInt32(&device.tun.mtu)
if err != nil {
logError.Println("Failed to load updated MTU of device:", err)
} else if int(old) != mtu {
if mtu+MessageTransportSize > MaxMessageSize {
logInfo.Println("MTU updated:", mtu, "(too large)")
} else {
logInfo.Println("MTU updated:", mtu)
}
atomic.StoreInt32(&device.tun.mtu, int32(mtu))
}
}
if event&tun.TUNEventUp != 0 && !setUp {
logInfo.Println("Interface set up")
setUp = true
device.Up()
}
if event&tun.TUNEventDown != 0 && setUp {
logInfo.Println("Interface set down")
setUp = false
device.Down()
}
}
logDebug.Println("Routine: event worker - stopped")
device.state.stopping.Done()
}

426
device/uapi.go Normal file
View File

@@ -0,0 +1,426 @@
/* SPDX-License-Identifier: MIT
*
* Copyright (C) 2017-2019 WireGuard LLC. All Rights Reserved.
*/
package device
import (
"bufio"
"fmt"
"golang.zx2c4.com/wireguard/ipc"
"io"
"net"
"strconv"
"strings"
"sync/atomic"
"time"
)
type IPCError struct {
int64
}
func (s *IPCError) Error() string {
return fmt.Sprintf("IPC error: %d", s.int64)
}
func (s *IPCError) ErrorCode() int64 {
return s.int64
}
func (device *Device) IpcGetOperation(socket *bufio.Writer) *IPCError {
device.log.Debug.Println("UAPI: Processing get operation")
// create lines
lines := make([]string, 0, 100)
send := func(line string) {
lines = append(lines, line)
}
func() {
// lock required resources
device.net.RLock()
defer device.net.RUnlock()
device.staticIdentity.RLock()
defer device.staticIdentity.RUnlock()
device.peers.RLock()
defer device.peers.RUnlock()
// serialize device related values
if !device.staticIdentity.privateKey.IsZero() {
send("private_key=" + device.staticIdentity.privateKey.ToHex())
}
if device.net.port != 0 {
send(fmt.Sprintf("listen_port=%d", device.net.port))
}
if device.net.fwmark != 0 {
send(fmt.Sprintf("fwmark=%d", device.net.fwmark))
}
// serialize each peer state
for _, peer := range device.peers.keyMap {
peer.RLock()
defer peer.RUnlock()
send("public_key=" + peer.handshake.remoteStatic.ToHex())
send("preshared_key=" + peer.handshake.presharedKey.ToHex())
send("protocol_version=1")
if peer.endpoint != nil {
send("endpoint=" + peer.endpoint.DstToString())
}
nano := atomic.LoadInt64(&peer.stats.lastHandshakeNano)
secs := nano / time.Second.Nanoseconds()
nano %= time.Second.Nanoseconds()
send(fmt.Sprintf("last_handshake_time_sec=%d", secs))
send(fmt.Sprintf("last_handshake_time_nsec=%d", nano))
send(fmt.Sprintf("tx_bytes=%d", atomic.LoadUint64(&peer.stats.txBytes)))
send(fmt.Sprintf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes)))
send(fmt.Sprintf("persistent_keepalive_interval=%d", peer.persistentKeepaliveInterval))
for _, ip := range device.allowedips.EntriesForPeer(peer) {
send("allowed_ip=" + ip.String())
}
}
}()
// send lines (does not require resource locks)
for _, line := range lines {
_, err := socket.WriteString(line + "\n")
if err != nil {
return &IPCError{ipc.IpcErrorIO}
}
}
return nil
}
func (device *Device) IpcSetOperation(socket *bufio.Reader) *IPCError {
scanner := bufio.NewScanner(socket)
logError := device.log.Error
logDebug := device.log.Debug
var peer *Peer
dummy := false
deviceConfig := true
for scanner.Scan() {
// parse line
line := scanner.Text()
if line == "" {
return nil
}
parts := strings.Split(line, "=")
if len(parts) != 2 {
return &IPCError{ipc.IpcErrorProtocol}
}
key := parts[0]
value := parts[1]
/* device configuration */
if deviceConfig {
switch key {
case "private_key":
var sk NoisePrivateKey
err := sk.FromHex(value)
if err != nil {
logError.Println("Failed to set private_key:", err)
return &IPCError{ipc.IpcErrorInvalid}
}
logDebug.Println("UAPI: Updating private key")
device.SetPrivateKey(sk)
case "listen_port":
// parse port number
port, err := strconv.ParseUint(value, 10, 16)
if err != nil {
logError.Println("Failed to parse listen_port:", err)
return &IPCError{ipc.IpcErrorInvalid}
}
// update port and rebind
logDebug.Println("UAPI: Updating listen port")
device.net.Lock()
device.net.port = uint16(port)
device.net.Unlock()
if err := device.BindUpdate(); err != nil {
logError.Println("Failed to set listen_port:", err)
return &IPCError{ipc.IpcErrorPortInUse}
}
case "fwmark":
// parse fwmark field
fwmark, err := func() (uint32, error) {
if value == "" {
return 0, nil
}
mark, err := strconv.ParseUint(value, 10, 32)
return uint32(mark), err
}()
if err != nil {
logError.Println("Invalid fwmark", err)
return &IPCError{ipc.IpcErrorInvalid}
}
logDebug.Println("UAPI: Updating fwmark")
if err := device.BindSetMark(uint32(fwmark)); err != nil {
logError.Println("Failed to update fwmark:", err)
return &IPCError{ipc.IpcErrorPortInUse}
}
case "public_key":
// switch to peer configuration
logDebug.Println("UAPI: Transition to peer configuration")
deviceConfig = false
case "replace_peers":
if value != "true" {
logError.Println("Failed to set replace_peers, invalid value:", value)
return &IPCError{ipc.IpcErrorInvalid}
}
logDebug.Println("UAPI: Removing all peers")
device.RemoveAllPeers()
default:
logError.Println("Invalid UAPI device key:", key)
return &IPCError{ipc.IpcErrorInvalid}
}
}
/* peer configuration */
if !deviceConfig {
switch key {
case "public_key":
var publicKey NoisePublicKey
err := publicKey.FromHex(value)
if err != nil {
logError.Println("Failed to get peer by public key:", err)
return &IPCError{ipc.IpcErrorInvalid}
}
// ignore peer with public key of device
device.staticIdentity.RLock()
dummy = device.staticIdentity.publicKey.Equals(publicKey)
device.staticIdentity.RUnlock()
if dummy {
peer = &Peer{}
} else {
peer = device.LookupPeer(publicKey)
}
if peer == nil {
peer, err = device.NewPeer(publicKey)
if err != nil {
logError.Println("Failed to create new peer:", err)
return &IPCError{ipc.IpcErrorInvalid}
}
logDebug.Println(peer, "- UAPI: Created")
}
case "remove":
// remove currently selected peer from device
if value != "true" {
logError.Println("Failed to set remove, invalid value:", value)
return &IPCError{ipc.IpcErrorInvalid}
}
if !dummy {
logDebug.Println(peer, "- UAPI: Removing")
device.RemovePeer(peer.handshake.remoteStatic)
}
peer = &Peer{}
dummy = true
case "preshared_key":
// update PSK
logDebug.Println(peer, "- UAPI: Updating preshared key")
peer.handshake.mutex.Lock()
err := peer.handshake.presharedKey.FromHex(value)
peer.handshake.mutex.Unlock()
if err != nil {
logError.Println("Failed to set preshared key:", err)
return &IPCError{ipc.IpcErrorInvalid}
}
case "endpoint":
// set endpoint destination
logDebug.Println(peer, "- UAPI: Updating endpoint")
err := func() error {
peer.Lock()
defer peer.Unlock()
endpoint, err := CreateEndpoint(value)
if err != nil {
return err
}
peer.endpoint = endpoint
return nil
}()
if err != nil {
logError.Println("Failed to set endpoint:", value)
return &IPCError{ipc.IpcErrorInvalid}
}
case "persistent_keepalive_interval":
// update persistent keepalive interval
logDebug.Println(peer, "- UAPI: Updating persistent keepalive interval")
secs, err := strconv.ParseUint(value, 10, 16)
if err != nil {
logError.Println("Failed to set persistent keepalive interval:", err)
return &IPCError{ipc.IpcErrorInvalid}
}
old := peer.persistentKeepaliveInterval
peer.persistentKeepaliveInterval = uint16(secs)
// send immediate keepalive if we're turning it on and before it wasn't on
if old == 0 && secs != 0 {
if err != nil {
logError.Println("Failed to get tun device status:", err)
return &IPCError{ipc.IpcErrorIO}
}
if device.isUp.Get() && !dummy {
peer.SendKeepalive()
}
}
case "replace_allowed_ips":
logDebug.Println(peer, "- UAPI: Removing all allowedips")
if value != "true" {
logError.Println("Failed to replace allowedips, invalid value:", value)
return &IPCError{ipc.IpcErrorInvalid}
}
if dummy {
continue
}
device.allowedips.RemoveByPeer(peer)
case "allowed_ip":
logDebug.Println(peer, "- UAPI: Adding allowedip")
_, network, err := net.ParseCIDR(value)
if err != nil {
logError.Println("Failed to set allowed ip:", err)
return &IPCError{ipc.IpcErrorInvalid}
}
if dummy {
continue
}
ones, _ := network.Mask.Size()
device.allowedips.Insert(network.IP, uint(ones), peer)
case "protocol_version":
if value != "1" {
logError.Println("Invalid protocol version:", value)
return &IPCError{ipc.IpcErrorInvalid}
}
default:
logError.Println("Invalid UAPI peer key:", key)
return &IPCError{ipc.IpcErrorInvalid}
}
}
}
return nil
}
func (device *Device) IpcHandle(socket net.Conn) {
// create buffered read/writer
defer socket.Close()
buffered := func(s io.ReadWriter) *bufio.ReadWriter {
reader := bufio.NewReader(s)
writer := bufio.NewWriter(s)
return bufio.NewReadWriter(reader, writer)
}(socket)
defer buffered.Flush()
op, err := buffered.ReadString('\n')
if err != nil {
return
}
// handle operation
var status *IPCError
switch op {
case "set=1\n":
device.log.Debug.Println("UAPI: Set operation")
status = device.IpcSetOperation(buffered.Reader)
case "get=1\n":
device.log.Debug.Println("UAPI: Get operation")
status = device.IpcGetOperation(buffered.Writer)
default:
device.log.Error.Println("Invalid UAPI operation:", op)
return
}
// write status
if status != nil {
device.log.Error.Println(status)
fmt.Fprintf(buffered, "errno=%d\n\n", status.ErrorCode())
} else {
fmt.Fprintf(buffered, "errno=0\n\n")
}
}

3
device/version.go Normal file
View File

@@ -0,0 +1,3 @@
package device
const WireGuardGoVersion = "0.0.20181222"