global: use netip where possible now

There are more places where we'll need to add it later, when Go 1.18
comes out with support for it in the "net" package. Also, allowedips
still uses slices internally, which might be suboptimal.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld
2021-11-05 01:52:54 +01:00
parent de7c702ace
commit ef8d6804d7
22 changed files with 247 additions and 285 deletions

View File

@@ -1,4 +1,5 @@
//go:build ignore
// +build ignore
/* SPDX-License-Identifier: MIT
*
@@ -10,9 +11,9 @@ package main
import (
"io"
"log"
"net"
"net/http"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -20,8 +21,8 @@ import (
func main() {
tun, tnet, err := netstack.CreateNetTUN(
[]net.IP{net.ParseIP("192.168.4.29")},
[]net.IP{net.ParseIP("8.8.8.8")},
[]netip.Addr{netip.MustParseAddr("192.168.4.29")},
[]netip.Addr{netip.MustParseAddr("8.8.8.8")},
1420)
if err != nil {
log.Panic(err)

View File

@@ -1,4 +1,5 @@
//go:build ignore
// +build ignore
/* SPDX-License-Identifier: MIT
*
@@ -13,6 +14,7 @@ import (
"net"
"net/http"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
@@ -20,8 +22,8 @@ import (
func main() {
tun, tnet, err := netstack.CreateNetTUN(
[]net.IP{net.ParseIP("192.168.4.29")},
[]net.IP{net.ParseIP("8.8.8.8"), net.ParseIP("8.8.4.4")},
[]netip.Addr{netip.MustParseAddr("192.168.4.29")},
[]netip.Addr{netip.MustParseAddr("8.8.8.8"), netip.MustParseAddr("8.8.4.4")},
1420,
)
if err != nil {

View File

@@ -6,6 +6,7 @@ require (
golang.org/x/net v0.0.0-20210423184538-5f58ad60dda6
golang.org/x/sys v0.0.0-20210423185535-09eb48e85fd7 // indirect
golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect
golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22
gvisor.dev/gvisor v0.0.0-20211020211948-f76a604701b6
)

View File

@@ -805,6 +805,10 @@ golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8T
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5 h1:mV4w4F7AtWXoDNkko9odoTdWpNwyDh8jx+S1fOZKDLg=
golang.zx2c4.com/go118/netip v0.0.0-20211104120624-f0ae7a6e37c5/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53 h1:nFvpdzrHF9IPo9xPgayHWObCATpQYKky8VSSdt9lf9E=
golang.zx2c4.com/go118/netip v0.0.0-20211105124833-002a02cb0e53/go.mod h1:5yyfuiqVIJ7t+3MqrpTQ+QqRkMWiESiyDvPNvKYCecg=
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22 h1:ytS28bw9HtZVDRMDxviC6ryCJuccw+zXhh04u2IRWJw=
golang.zx2c4.com/wireguard v0.0.0-20210424170727-c9db4b7aaa22/go.mod h1:a057zjmoc00UN7gVkaJt2sXVK523kMJcogDTEvPIasg=
google.golang.org/api v0.4.0/go.mod h1:8k5glujaEP+g9n7WNsDg8QP6cUVNI86fCNMcbazEtwE=

View File

@@ -18,6 +18,7 @@ import (
"strings"
"time"
"golang.zx2c4.com/go118/netip"
"golang.zx2c4.com/wireguard/tun"
"golang.org/x/net/dns/dnsmessage"
@@ -38,7 +39,7 @@ type netTun struct {
events chan tun.Event
incomingPacket chan buffer.VectorisedView
mtu int
dnsServers []net.IP
dnsServers []netip.Addr
hasV4, hasV6 bool
}
type endpoint netTun
@@ -94,7 +95,7 @@ func (*endpoint) ARPHardwareType() header.ARPHardwareType {
func (e *endpoint) AddHeader(tcpip.LinkAddress, tcpip.LinkAddress, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) {
}
func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Net, error) {
func CreateNetTUN(localAddresses, dnsServers []netip.Addr, mtu int) (tun.Device, *Net, error) {
opts := stack.Options{
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol, udp.NewProtocol},
@@ -112,25 +113,23 @@ func CreateNetTUN(localAddresses, dnsServers []net.IP, mtu int) (tun.Device, *Ne
return nil, nil, fmt.Errorf("CreateNIC: %v", tcpipErr)
}
for _, ip := range localAddresses {
if ip4 := ip.To4(); ip4 != nil {
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv4.ProtocolNumber,
AddressWithPrefix: tcpip.Address(ip4).WithPrefix(),
}
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip4, tcpipErr)
}
var protoNumber tcpip.NetworkProtocolNumber
if ip.Is4() {
protoNumber = ipv4.ProtocolNumber
} else if ip.Is6() {
protoNumber = ipv6.ProtocolNumber
}
protoAddr := tcpip.ProtocolAddress{
Protocol: protoNumber,
AddressWithPrefix: tcpip.Address(ip.AsSlice()).WithPrefix(),
}
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
}
if ip.Is4() {
dev.hasV4 = true
} else {
protoAddr := tcpip.ProtocolAddress{
Protocol: ipv6.ProtocolNumber,
AddressWithPrefix: tcpip.Address(ip).WithPrefix(),
}
tcpipErr := dev.stack.AddProtocolAddress(1, protoAddr, stack.AddressProperties{})
if tcpipErr != nil {
return nil, nil, fmt.Errorf("AddProtocolAddress(%v): %v", ip, tcpipErr)
}
} else if ip.Is6() {
dev.hasV6 = true
}
}
@@ -202,62 +201,83 @@ func (tun *netTun) MTU() (int, error) {
return tun.mtu, nil
}
func convertToFullAddr(ip net.IP, port int) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
if ip4 := ip.To4(); ip4 != nil {
return tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(ip4),
Port: uint16(port),
}, ipv4.ProtocolNumber
func convertToFullAddr(endpoint netip.AddrPort) (tcpip.FullAddress, tcpip.NetworkProtocolNumber) {
var protoNumber tcpip.NetworkProtocolNumber
if endpoint.Addr().Is4() {
protoNumber = ipv4.ProtocolNumber
} else {
return tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(ip),
Port: uint16(port),
}, ipv6.ProtocolNumber
protoNumber = ipv6.ProtocolNumber
}
return tcpip.FullAddress{
NIC: 1,
Addr: tcpip.Address(endpoint.Addr().AsSlice()),
Port: endpoint.Port(),
}, protoNumber
}
func (net *Net) DialContextTCPAddrPort(ctx context.Context, addr netip.AddrPort) (*gonet.TCPConn, error) {
fa, pn := convertToFullAddr(addr)
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
}
func (net *Net) DialContextTCP(ctx context.Context, addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
panic("todo: deal with auto addr semantics for nil addr")
return net.DialContextTCPAddrPort(ctx, netip.AddrPort{})
}
fa, pn := convertToFullAddr(addr.IP, addr.Port)
return gonet.DialContextTCP(ctx, net.stack, fa, pn)
return net.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
}
func (net *Net) DialTCPAddrPort(addr netip.AddrPort) (*gonet.TCPConn, error) {
fa, pn := convertToFullAddr(addr)
return gonet.DialTCP(net.stack, fa, pn)
}
func (net *Net) DialTCP(addr *net.TCPAddr) (*gonet.TCPConn, error) {
if addr == nil {
panic("todo: deal with auto addr semantics for nil addr")
return net.DialTCPAddrPort(netip.AddrPort{})
}
fa, pn := convertToFullAddr(addr.IP, addr.Port)
return gonet.DialTCP(net.stack, fa, pn)
return net.DialTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
}
func (net *Net) ListenTCPAddrPort(addr netip.AddrPort) (*gonet.TCPListener, error) {
fa, pn := convertToFullAddr(addr)
return gonet.ListenTCP(net.stack, fa, pn)
}
func (net *Net) ListenTCP(addr *net.TCPAddr) (*gonet.TCPListener, error) {
if addr == nil {
panic("todo: deal with auto addr semantics for nil addr")
return net.ListenTCPAddrPort(netip.AddrPort{})
}
fa, pn := convertToFullAddr(addr.IP, addr.Port)
return gonet.ListenTCP(net.stack, fa, pn)
return net.ListenTCPAddrPort(netip.AddrPortFrom(netip.AddrFromSlice(addr.IP), uint16(addr.Port)))
}
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
func (net *Net) DialUDPAddrPort(laddr, raddr netip.AddrPort) (*gonet.UDPConn, error) {
var lfa, rfa *tcpip.FullAddress
var pn tcpip.NetworkProtocolNumber
if laddr != nil {
if laddr.IsValid() || laddr.Port() > 0 {
var addr tcpip.FullAddress
addr, pn = convertToFullAddr(laddr.IP, laddr.Port)
addr, pn = convertToFullAddr(laddr)
lfa = &addr
}
if raddr != nil {
if raddr.IsValid() || raddr.Port() > 0 {
var addr tcpip.FullAddress
addr, pn = convertToFullAddr(raddr.IP, raddr.Port)
addr, pn = convertToFullAddr(raddr)
rfa = &addr
}
return gonet.DialUDP(net.stack, lfa, rfa, pn)
}
func (net *Net) DialUDP(laddr, raddr *net.UDPAddr) (*gonet.UDPConn, error) {
var la, ra netip.AddrPort
if laddr != nil {
la = netip.AddrPortFrom(netip.AddrFromSlice(laddr.IP), uint16(laddr.Port))
}
if raddr != nil {
ra = netip.AddrPortFrom(netip.AddrFromSlice(raddr.IP), uint16(raddr.Port))
}
return net.DialUDPAddrPort(la, ra)
}
var (
errNoSuchHost = errors.New("no such host")
errLameReferral = errors.New("lame referral")
@@ -433,7 +453,7 @@ func dnsStreamRoundTrip(c net.Conn, id uint16, query dnsmessage.Question, b []by
return p, h, nil
}
func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
func (tnet *Net) exchange(ctx context.Context, server netip.Addr, q dnsmessage.Question, timeout time.Duration) (dnsmessage.Parser, dnsmessage.Header, error) {
q.Class = dnsmessage.ClassINET
id, udpReq, tcpReq, err := newRequest(q)
if err != nil {
@@ -447,9 +467,9 @@ func (tnet *Net) exchange(ctx context.Context, server net.IP, q dnsmessage.Quest
var c net.Conn
var err error
if useUDP {
c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: server, Port: 53})
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, netip.AddrPortFrom(server, 53))
} else {
c, err = tnet.DialContextTCP(ctx, &net.TCPAddr{IP: server, Port: 53})
c, err = tnet.DialContextTCPAddrPort(ctx, netip.AddrPortFrom(server, 53))
}
if err != nil {
@@ -600,8 +620,8 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
zlen = zidx
}
}
if ip := net.ParseIP(host[:zlen]); ip != nil {
return []string{host[:zlen]}, nil
if ip, err := netip.ParseAddr(host[:zlen]); err == nil {
return []string{ip.String()}, nil
}
if !isDomainName(host) {
@@ -612,7 +632,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
server string
error
}
var addrsV4, addrsV6 []net.IP
var addrsV4, addrsV6 []netip.Addr
lanes := 0
if tnet.hasV4 {
lanes++
@@ -667,7 +687,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
addrsV4 = append(addrsV4, net.IP(a.A[:]))
addrsV4 = append(addrsV4, netip.AddrFrom4(a.A))
case dnsmessage.TypeAAAA:
aaaa, err := result.p.AAAAResource()
@@ -679,7 +699,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
break loop
}
addrsV6 = append(addrsV6, net.IP(aaaa.AAAA[:]))
addrsV6 = append(addrsV6, netip.AddrFrom16(aaaa.AAAA))
default:
if err := result.p.SkipAnswer(); err != nil {
@@ -695,7 +715,7 @@ func (tnet *Net) LookupContextHost(ctx context.Context, host string) ([]string,
}
}
// We don't do RFC6724. Instead just put V6 addresess first if an IPv6 address is enabled
var addrs []net.IP
var addrs []netip.Addr
if tnet.hasV6 {
addrs = append(addrsV6, addrsV4...)
} else {
@@ -764,12 +784,11 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
if err != nil {
return nil, &net.OpError{Op: "dial", Err: err}
}
var addrs []net.IP
var addrs []netip.AddrPort
for _, addr := range allAddr {
if strings.IndexByte(addr, ':') != -1 && acceptV6 {
addrs = append(addrs, net.ParseIP(addr))
} else if strings.IndexByte(addr, '.') != -1 && acceptV4 {
addrs = append(addrs, net.ParseIP(addr))
ip, err := netip.ParseAddr(addr)
if err == nil && ((ip.Is4() && acceptV4) || (ip.Is6() && acceptV6)) {
addrs = append(addrs, netip.AddrPortFrom(ip, uint16(port)))
}
}
if len(addrs) == 0 && len(allAddr) != 0 {
@@ -808,9 +827,9 @@ func (tnet *Net) DialContext(ctx context.Context, network, address string) (net.
var c net.Conn
if useUDP {
c, err = tnet.DialUDP(nil, &net.UDPAddr{IP: addr, Port: port})
c, err = tnet.DialUDPAddrPort(netip.AddrPort{}, addr)
} else {
c, err = tnet.DialContextTCP(dialCtx, &net.TCPAddr{IP: addr, Port: port})
c, err = tnet.DialContextTCPAddrPort(dialCtx, addr)
}
if err == nil {
return c, nil