global: use netip where possible now

There are more places where we'll need to add it later, when Go 1.18
comes out with support for it in the "net" package. Also, allowedips
still uses slices internally, which might be suboptimal.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld
2021-11-05 01:52:54 +01:00
parent de7c702ace
commit ef8d6804d7
22 changed files with 247 additions and 285 deletions

View File

@@ -12,6 +12,8 @@ import (
"net"
"sync"
"unsafe"
"golang.zx2c4.com/go118/netip"
)
type parentIndirection struct {
@@ -26,7 +28,7 @@ type trieEntry struct {
cidr uint8
bitAtByte uint8
bitAtShift uint8
bits net.IP
bits []byte
perPeerElem *list.Element
}
@@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 {
return bits.ReverseBytes64(i)
}
func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
func commonBits(ip1, ip2 []byte) uint8 {
size := len(ip1)
if size == net.IPv4len {
a := (*uint32)(unsafe.Pointer(&ip1[0]))
@@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() {
}
}
func (node *trieEntry) choose(ip net.IP) byte {
func (node *trieEntry) choose(ip []byte) byte {
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
}
@@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() {
node.parent.parentBit = nil
}
func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
parent = node
if parent.cidr == cidr {
@@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry,
return
}
func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
if *trie.parentBit == nil {
node := &trieEntry{
peer: peer,
@@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
}
}
func (node *trieEntry) lookup(ip net.IP) *Peer {
func (node *trieEntry) lookup(ip []byte) *Peer {
var found *Peer
size := uint8(len(ip))
for node != nil && commonBits(node.bits, ip) >= node.cidr {
@@ -229,13 +231,14 @@ type AllowedIPs struct {
mutex sync.RWMutex
}
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
table.mutex.RLock()
defer table.mutex.RUnlock()
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
node := elem.Value.(*trieEntry)
if !cb(node.bits, node.cidr) {
a, _ := netip.AddrFromSlice(node.bits)
if !cb(netip.PrefixFrom(a, int(node.cidr))) {
return
}
}
@@ -283,28 +286,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
}
}
func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
table.mutex.Lock()
defer table.mutex.Unlock()
switch len(ip) {
case net.IPv6len:
parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
case net.IPv4len:
parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
default:
if prefix.Addr().Is6() {
ip := prefix.Addr().As16()
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else if prefix.Addr().Is4() {
ip := prefix.Addr().As4()
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
} else {
panic(errors.New("inserting unknown address type"))
}
}
func (table *AllowedIPs) Lookup(address []byte) *Peer {
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
table.mutex.RLock()
defer table.mutex.RUnlock()
switch len(address) {
switch len(ip) {
case net.IPv6len:
return table.IPv6.lookup(address)
return table.IPv6.lookup(ip)
case net.IPv4len:
return table.IPv4.lookup(address)
return table.IPv4.lookup(ip)
default:
panic(errors.New("looking up unknown address type"))
}

View File

@@ -10,6 +10,8 @@ import (
"net"
"sort"
"testing"
"golang.zx2c4.com/go118/netip"
)
const (
@@ -93,14 +95,14 @@ func TestTrieRandom(t *testing.T) {
rand.Read(addr4[:])
cidr := uint8(rand.Intn(32) + 1)
index := rand.Intn(NumberOfPeers)
allowedIPs.Insert(addr4[:], cidr, peers[index])
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
var addr6 [16]byte
rand.Read(addr6[:])
cidr = uint8(rand.Intn(128) + 1)
index = rand.Intn(NumberOfPeers)
allowedIPs.Insert(addr6[:], cidr, peers[index])
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
}

View File

@@ -9,6 +9,8 @@ import (
"math/rand"
"net"
"testing"
"golang.zx2c4.com/go118/netip"
)
type testPairCommonBits struct {
@@ -98,7 +100,7 @@ func TestTrieIPv4(t *testing.T) {
var allowedIPs AllowedIPs
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d byte) {
@@ -208,7 +210,7 @@ func TestTrieIPv6(t *testing.T) {
addr = append(addr, expand(b)...)
addr = append(addr, expand(c)...)
addr = append(addr, expand(d)...)
allowedIPs.Insert(addr, cidr, peer)
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
}
assertEQ := func(peer *Peer, a, b, c, d uint32) {

View File

@@ -11,7 +11,6 @@ import (
"fmt"
"io"
"math/rand"
"net"
"runtime"
"runtime/pprof"
"sync"
@@ -19,6 +18,7 @@ import (
"testing"
"time"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/conn/bindtest"
"golang.zx2c4.com/wireguard/tun/tuntest"
@@ -96,7 +96,7 @@ type testPair [2]testPeer
type testPeer struct {
tun *tuntest.ChannelTUN
dev *Device
ip net.IP
ip netip.Addr
}
type SendDirection bool
@@ -159,7 +159,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
for i := range pair {
p := &pair[i]
p.tun = tuntest.NewChannelTUN()
p.ip = net.IPv4(1, 0, 0, byte(i+1))
p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
level := LogLevelVerbose
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
level = LogLevelError

View File

@@ -7,47 +7,44 @@ package device
import (
"math/rand"
"net"
"golang.zx2c4.com/go118/netip"
)
type DummyEndpoint struct {
src [16]byte
dst [16]byte
src, dst netip.Addr
}
func CreateDummyEndpoint() (*DummyEndpoint, error) {
var end DummyEndpoint
if _, err := rand.Read(end.src[:]); err != nil {
var src, dst [16]byte
if _, err := rand.Read(src[:]); err != nil {
return nil, err
}
_, err := rand.Read(end.dst[:])
return &end, err
_, err := rand.Read(dst[:])
return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
}
func (e *DummyEndpoint) ClearSrc() {}
func (e *DummyEndpoint) SrcToString() string {
var addr net.UDPAddr
addr.IP = e.SrcIP()
addr.Port = 1000
return addr.String()
return netip.AddrPortFrom(e.SrcIP(), 1000).String()
}
func (e *DummyEndpoint) DstToString() string {
var addr net.UDPAddr
addr.IP = e.DstIP()
addr.Port = 1000
return addr.String()
return netip.AddrPortFrom(e.DstIP(), 1000).String()
}
func (e *DummyEndpoint) SrcToBytes() []byte {
return e.src[:]
func (e *DummyEndpoint) DstToBytes() []byte {
out := e.DstIP().AsSlice()
out = append(out, byte(1000&0xff))
out = append(out, byte((1000>>8)&0xff))
return out
}
func (e *DummyEndpoint) DstIP() net.IP {
return e.dst[:]
func (e *DummyEndpoint) DstIP() netip.Addr {
return e.dst
}
func (e *DummyEndpoint) SrcIP() net.IP {
return e.src[:]
func (e *DummyEndpoint) SrcIP() netip.Addr {
return e.src
}

View File

@@ -17,7 +17,6 @@ import (
"golang.org/x/crypto/chacha20poly1305"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn"
)

View File

@@ -18,6 +18,7 @@ import (
"sync/atomic"
"time"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/ipc"
)
@@ -121,8 +122,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
sendf("allowed_ip=%s/%d", ip.String(), cidr)
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
sendf("allowed_ip=%s", prefix.String())
return true
})
}
@@ -374,16 +375,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
case "allowed_ip":
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
_, network, err := net.ParseCIDR(value)
prefix, err := netip.ParsePrefix(value)
if err != nil {
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
}
if peer.dummy {
return nil
}
ones, _ := network.Mask.Size()
device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
device.allowedips.Insert(prefix, peer.Peer)
case "protocol_version":
if value != "1" {