global: use netip where possible now
There are more places where we'll need to add it later, when Go 1.18 comes out with support for it in the "net" package. Also, allowedips still uses slices internally, which might be suboptimal. Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
@@ -12,6 +12,8 @@ import (
|
||||
"net"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"golang.zx2c4.com/go118/netip"
|
||||
)
|
||||
|
||||
type parentIndirection struct {
|
||||
@@ -26,7 +28,7 @@ type trieEntry struct {
|
||||
cidr uint8
|
||||
bitAtByte uint8
|
||||
bitAtShift uint8
|
||||
bits net.IP
|
||||
bits []byte
|
||||
perPeerElem *list.Element
|
||||
}
|
||||
|
||||
@@ -51,7 +53,7 @@ func swapU64(i uint64) uint64 {
|
||||
return bits.ReverseBytes64(i)
|
||||
}
|
||||
|
||||
func commonBits(ip1 net.IP, ip2 net.IP) uint8 {
|
||||
func commonBits(ip1, ip2 []byte) uint8 {
|
||||
size := len(ip1)
|
||||
if size == net.IPv4len {
|
||||
a := (*uint32)(unsafe.Pointer(&ip1[0]))
|
||||
@@ -85,7 +87,7 @@ func (node *trieEntry) removeFromPeerEntries() {
|
||||
}
|
||||
}
|
||||
|
||||
func (node *trieEntry) choose(ip net.IP) byte {
|
||||
func (node *trieEntry) choose(ip []byte) byte {
|
||||
return (ip[node.bitAtByte] >> node.bitAtShift) & 1
|
||||
}
|
||||
|
||||
@@ -104,7 +106,7 @@ func (node *trieEntry) zeroizePointers() {
|
||||
node.parent.parentBit = nil
|
||||
}
|
||||
|
||||
func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry, exact bool) {
|
||||
func (node *trieEntry) nodePlacement(ip []byte, cidr uint8) (parent *trieEntry, exact bool) {
|
||||
for node != nil && node.cidr <= cidr && commonBits(node.bits, ip) >= node.cidr {
|
||||
parent = node
|
||||
if parent.cidr == cidr {
|
||||
@@ -117,7 +119,7 @@ func (node *trieEntry) nodePlacement(ip net.IP, cidr uint8) (parent *trieEntry,
|
||||
return
|
||||
}
|
||||
|
||||
func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
|
||||
func (trie parentIndirection) insert(ip []byte, cidr uint8, peer *Peer) {
|
||||
if *trie.parentBit == nil {
|
||||
node := &trieEntry{
|
||||
peer: peer,
|
||||
@@ -207,7 +209,7 @@ func (trie parentIndirection) insert(ip net.IP, cidr uint8, peer *Peer) {
|
||||
}
|
||||
}
|
||||
|
||||
func (node *trieEntry) lookup(ip net.IP) *Peer {
|
||||
func (node *trieEntry) lookup(ip []byte) *Peer {
|
||||
var found *Peer
|
||||
size := uint8(len(ip))
|
||||
for node != nil && commonBits(node.bits, ip) >= node.cidr {
|
||||
@@ -229,13 +231,14 @@ type AllowedIPs struct {
|
||||
mutex sync.RWMutex
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(ip net.IP, cidr uint8) bool) {
|
||||
func (table *AllowedIPs) EntriesForPeer(peer *Peer, cb func(prefix netip.Prefix) bool) {
|
||||
table.mutex.RLock()
|
||||
defer table.mutex.RUnlock()
|
||||
|
||||
for elem := peer.trieEntries.Front(); elem != nil; elem = elem.Next() {
|
||||
node := elem.Value.(*trieEntry)
|
||||
if !cb(node.bits, node.cidr) {
|
||||
a, _ := netip.AddrFromSlice(node.bits)
|
||||
if !cb(netip.PrefixFrom(a, int(node.cidr))) {
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -283,28 +286,29 @@ func (table *AllowedIPs) RemoveByPeer(peer *Peer) {
|
||||
}
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) Insert(ip net.IP, cidr uint8, peer *Peer) {
|
||||
func (table *AllowedIPs) Insert(prefix netip.Prefix, peer *Peer) {
|
||||
table.mutex.Lock()
|
||||
defer table.mutex.Unlock()
|
||||
|
||||
switch len(ip) {
|
||||
case net.IPv6len:
|
||||
parentIndirection{&table.IPv6, 2}.insert(ip, cidr, peer)
|
||||
case net.IPv4len:
|
||||
parentIndirection{&table.IPv4, 2}.insert(ip, cidr, peer)
|
||||
default:
|
||||
if prefix.Addr().Is6() {
|
||||
ip := prefix.Addr().As16()
|
||||
parentIndirection{&table.IPv6, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
|
||||
} else if prefix.Addr().Is4() {
|
||||
ip := prefix.Addr().As4()
|
||||
parentIndirection{&table.IPv4, 2}.insert(ip[:], uint8(prefix.Bits()), peer)
|
||||
} else {
|
||||
panic(errors.New("inserting unknown address type"))
|
||||
}
|
||||
}
|
||||
|
||||
func (table *AllowedIPs) Lookup(address []byte) *Peer {
|
||||
func (table *AllowedIPs) Lookup(ip []byte) *Peer {
|
||||
table.mutex.RLock()
|
||||
defer table.mutex.RUnlock()
|
||||
switch len(address) {
|
||||
switch len(ip) {
|
||||
case net.IPv6len:
|
||||
return table.IPv6.lookup(address)
|
||||
return table.IPv6.lookup(ip)
|
||||
case net.IPv4len:
|
||||
return table.IPv4.lookup(address)
|
||||
return table.IPv4.lookup(ip)
|
||||
default:
|
||||
panic(errors.New("looking up unknown address type"))
|
||||
}
|
||||
|
||||
@@ -10,6 +10,8 @@ import (
|
||||
"net"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"golang.zx2c4.com/go118/netip"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -93,14 +95,14 @@ func TestTrieRandom(t *testing.T) {
|
||||
rand.Read(addr4[:])
|
||||
cidr := uint8(rand.Intn(32) + 1)
|
||||
index := rand.Intn(NumberOfPeers)
|
||||
allowedIPs.Insert(addr4[:], cidr, peers[index])
|
||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4(addr4), int(cidr)), peers[index])
|
||||
slow4 = slow4.Insert(addr4[:], cidr, peers[index])
|
||||
|
||||
var addr6 [16]byte
|
||||
rand.Read(addr6[:])
|
||||
cidr = uint8(rand.Intn(128) + 1)
|
||||
index = rand.Intn(NumberOfPeers)
|
||||
allowedIPs.Insert(addr6[:], cidr, peers[index])
|
||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(addr6), int(cidr)), peers[index])
|
||||
slow6 = slow6.Insert(addr6[:], cidr, peers[index])
|
||||
}
|
||||
|
||||
|
||||
@@ -9,6 +9,8 @@ import (
|
||||
"math/rand"
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"golang.zx2c4.com/go118/netip"
|
||||
)
|
||||
|
||||
type testPairCommonBits struct {
|
||||
@@ -98,7 +100,7 @@ func TestTrieIPv4(t *testing.T) {
|
||||
var allowedIPs AllowedIPs
|
||||
|
||||
insert := func(peer *Peer, a, b, c, d byte, cidr uint8) {
|
||||
allowedIPs.Insert([]byte{a, b, c, d}, cidr, peer)
|
||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom4([4]byte{a, b, c, d}), int(cidr)), peer)
|
||||
}
|
||||
|
||||
assertEQ := func(peer *Peer, a, b, c, d byte) {
|
||||
@@ -208,7 +210,7 @@ func TestTrieIPv6(t *testing.T) {
|
||||
addr = append(addr, expand(b)...)
|
||||
addr = append(addr, expand(c)...)
|
||||
addr = append(addr, expand(d)...)
|
||||
allowedIPs.Insert(addr, cidr, peer)
|
||||
allowedIPs.Insert(netip.PrefixFrom(netip.AddrFrom16(*(*[16]byte)(addr)), int(cidr)), peer)
|
||||
}
|
||||
|
||||
assertEQ := func(peer *Peer, a, b, c, d uint32) {
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"sync"
|
||||
@@ -19,6 +18,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/go118/netip"
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/conn/bindtest"
|
||||
"golang.zx2c4.com/wireguard/tun/tuntest"
|
||||
@@ -96,7 +96,7 @@ type testPair [2]testPeer
|
||||
type testPeer struct {
|
||||
tun *tuntest.ChannelTUN
|
||||
dev *Device
|
||||
ip net.IP
|
||||
ip netip.Addr
|
||||
}
|
||||
|
||||
type SendDirection bool
|
||||
@@ -159,7 +159,7 @@ func genTestPair(tb testing.TB, realSocket bool) (pair testPair) {
|
||||
for i := range pair {
|
||||
p := &pair[i]
|
||||
p.tun = tuntest.NewChannelTUN()
|
||||
p.ip = net.IPv4(1, 0, 0, byte(i+1))
|
||||
p.ip = netip.AddrFrom4([4]byte{1, 0, 0, byte(i + 1)})
|
||||
level := LogLevelVerbose
|
||||
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
||||
level = LogLevelError
|
||||
|
||||
@@ -7,47 +7,44 @@ package device
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"net"
|
||||
|
||||
"golang.zx2c4.com/go118/netip"
|
||||
)
|
||||
|
||||
type DummyEndpoint struct {
|
||||
src [16]byte
|
||||
dst [16]byte
|
||||
src, dst netip.Addr
|
||||
}
|
||||
|
||||
func CreateDummyEndpoint() (*DummyEndpoint, error) {
|
||||
var end DummyEndpoint
|
||||
if _, err := rand.Read(end.src[:]); err != nil {
|
||||
var src, dst [16]byte
|
||||
if _, err := rand.Read(src[:]); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err := rand.Read(end.dst[:])
|
||||
return &end, err
|
||||
_, err := rand.Read(dst[:])
|
||||
return &DummyEndpoint{netip.AddrFrom16(src), netip.AddrFrom16(dst)}, err
|
||||
}
|
||||
|
||||
func (e *DummyEndpoint) ClearSrc() {}
|
||||
|
||||
func (e *DummyEndpoint) SrcToString() string {
|
||||
var addr net.UDPAddr
|
||||
addr.IP = e.SrcIP()
|
||||
addr.Port = 1000
|
||||
return addr.String()
|
||||
return netip.AddrPortFrom(e.SrcIP(), 1000).String()
|
||||
}
|
||||
|
||||
func (e *DummyEndpoint) DstToString() string {
|
||||
var addr net.UDPAddr
|
||||
addr.IP = e.DstIP()
|
||||
addr.Port = 1000
|
||||
return addr.String()
|
||||
return netip.AddrPortFrom(e.DstIP(), 1000).String()
|
||||
}
|
||||
|
||||
func (e *DummyEndpoint) SrcToBytes() []byte {
|
||||
return e.src[:]
|
||||
func (e *DummyEndpoint) DstToBytes() []byte {
|
||||
out := e.DstIP().AsSlice()
|
||||
out = append(out, byte(1000&0xff))
|
||||
out = append(out, byte((1000>>8)&0xff))
|
||||
return out
|
||||
}
|
||||
|
||||
func (e *DummyEndpoint) DstIP() net.IP {
|
||||
return e.dst[:]
|
||||
func (e *DummyEndpoint) DstIP() netip.Addr {
|
||||
return e.dst
|
||||
}
|
||||
|
||||
func (e *DummyEndpoint) SrcIP() net.IP {
|
||||
return e.src[:]
|
||||
func (e *DummyEndpoint) SrcIP() netip.Addr {
|
||||
return e.src
|
||||
}
|
||||
|
||||
@@ -17,7 +17,6 @@ import (
|
||||
"golang.org/x/crypto/chacha20poly1305"
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
)
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/go118/netip"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
)
|
||||
|
||||
@@ -121,8 +122,8 @@ func (device *Device) IpcGetOperation(w io.Writer) error {
|
||||
sendf("rx_bytes=%d", atomic.LoadUint64(&peer.stats.rxBytes))
|
||||
sendf("persistent_keepalive_interval=%d", atomic.LoadUint32(&peer.persistentKeepaliveInterval))
|
||||
|
||||
device.allowedips.EntriesForPeer(peer, func(ip net.IP, cidr uint8) bool {
|
||||
sendf("allowed_ip=%s/%d", ip.String(), cidr)
|
||||
device.allowedips.EntriesForPeer(peer, func(prefix netip.Prefix) bool {
|
||||
sendf("allowed_ip=%s", prefix.String())
|
||||
return true
|
||||
})
|
||||
}
|
||||
@@ -374,16 +375,14 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
||||
|
||||
case "allowed_ip":
|
||||
device.log.Verbosef("%v - UAPI: Adding allowedip", peer.Peer)
|
||||
|
||||
_, network, err := net.ParseCIDR(value)
|
||||
prefix, err := netip.ParsePrefix(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set allowed ip: %w", err)
|
||||
}
|
||||
if peer.dummy {
|
||||
return nil
|
||||
}
|
||||
ones, _ := network.Mask.Size()
|
||||
device.allowedips.Insert(network.IP, uint8(ones), peer.Peer)
|
||||
device.allowedips.Insert(prefix, peer.Peer)
|
||||
|
||||
case "protocol_version":
|
||||
if value != "1" {
|
||||
|
||||
Reference in New Issue
Block a user