conn, device, tun: implement vectorized I/O on Linux

Implement TCP offloading via TSO and GRO for the Linux tun.Device, which
is made possible by virtio extensions in the kernel's TUN driver.

Delete conn.LinuxSocketEndpoint in favor of a collapsed conn.StdNetBind.
conn.StdNetBind makes use of recvmmsg() and sendmmsg() on Linux. All
platforms now fall under conn.StdNetBind, except for Windows, which
remains in conn.WinRingBind, which still needs to be adjusted to handle
multiple packets.

Also refactor sticky sockets support to eventually be applicable on
platforms other than just Linux. However Linux remains the sole platform
that fully implements it for now.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jordan Whited
2023-03-02 15:08:28 -08:00
committed by Jason A. Donenfeld
parent 3bb8fec7e4
commit 9e2f386022
24 changed files with 1877 additions and 794 deletions

View File

@@ -6,32 +6,91 @@
package conn
import (
"context"
"errors"
"net"
"net/netip"
"strconv"
"sync"
"syscall"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
// StdNetBind is meant to be a temporary solution on platforms for which
// the sticky socket / source caching behavior has not yet been implemented.
// It uses the Go's net package to implement networking.
// See LinuxSocketBind for a proper implementation on the Linux platform.
var (
_ Bind = (*StdNetBind)(nil)
)
// StdNetBind implements Bind for all platforms except Windows.
type StdNetBind struct {
mu sync.Mutex // protects following fields
ipv4 *net.UDPConn
ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
mu sync.Mutex // protects following fields
ipv4 *net.UDPConn
ipv6 *net.UDPConn
blackhole4 bool
blackhole6 bool
ipv4PC *ipv4.PacketConn
ipv6PC *ipv6.PacketConn
batchSize int
udpAddrPool sync.Pool
ipv4MsgsPool sync.Pool
ipv6MsgsPool sync.Pool
}
func NewStdNetBind() Bind { return &StdNetBind{} }
func NewStdNetBind() Bind { return NewStdNetBindBatch(DefaultBatchSize) }
type StdNetEndpoint netip.AddrPort
func NewStdNetBindBatch(maxBatchSize int) Bind {
if maxBatchSize == 0 {
maxBatchSize = DefaultBatchSize
}
return &StdNetBind{
batchSize: maxBatchSize,
udpAddrPool: sync.Pool{
New: func() any {
return &net.UDPAddr{
IP: make([]byte, 16),
}
},
},
ipv4MsgsPool: sync.Pool{
New: func() any {
msgs := make([]ipv4.Message, maxBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
}
return &msgs
},
},
ipv6MsgsPool: sync.Pool{
New: func() any {
msgs := make([]ipv6.Message, maxBatchSize)
for i := range msgs {
msgs[i].Buffers = make(net.Buffers, 1)
msgs[i].OOB = make([]byte, srcControlSize)
}
return &msgs
},
},
}
}
type StdNetEndpoint struct {
// AddrPort is the endpoint destination.
netip.AddrPort
// src is the current sticky source address and interface index, if supported.
src struct {
netip.Addr
ifidx int32
}
}
var (
_ Bind = (*StdNetBind)(nil)
_ Endpoint = StdNetEndpoint{}
_ Endpoint = &StdNetEndpoint{}
)
func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
@@ -39,31 +98,38 @@ func (*StdNetBind) ParseEndpoint(s string) (Endpoint, error) {
return asEndpoint(e), err
}
func (StdNetEndpoint) ClearSrc() {}
func (e StdNetEndpoint) DstIP() netip.Addr {
return (netip.AddrPort)(e).Addr()
func (e *StdNetEndpoint) ClearSrc() {
e.src.ifidx = 0
e.src.Addr = netip.Addr{}
}
func (e StdNetEndpoint) SrcIP() netip.Addr {
return netip.Addr{} // not supported
func (e *StdNetEndpoint) DstIP() netip.Addr {
return e.AddrPort.Addr()
}
func (e StdNetEndpoint) DstToBytes() []byte {
b, _ := (netip.AddrPort)(e).MarshalBinary()
func (e *StdNetEndpoint) SrcIP() netip.Addr {
return e.src.Addr
}
func (e *StdNetEndpoint) SrcIfidx() int32 {
return e.src.ifidx
}
func (e *StdNetEndpoint) DstToBytes() []byte {
b, _ := e.AddrPort.MarshalBinary()
return b
}
func (e StdNetEndpoint) DstToString() string {
return (netip.AddrPort)(e).String()
func (e *StdNetEndpoint) DstToString() string {
return e.AddrPort.String()
}
func (e StdNetEndpoint) SrcToString() string {
return ""
func (e *StdNetEndpoint) SrcToString() string {
return e.src.Addr.String()
}
func listenNet(network string, port int) (*net.UDPConn, int, error) {
conn, err := net.ListenUDP(network, &net.UDPAddr{Port: port})
conn, err := listenConfig().ListenPacket(context.Background(), network, ":"+strconv.Itoa(port))
if err != nil {
return nil, 0, err
}
@@ -77,17 +143,17 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
if err != nil {
return nil, 0, err
}
return conn, uaddr.Port, nil
return conn.(*net.UDPConn), uaddr.Port, nil
}
func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
bind.mu.Lock()
defer bind.mu.Unlock()
func (s *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
s.mu.Lock()
defer s.mu.Unlock()
var err error
var tries int
if bind.ipv4 != nil || bind.ipv6 != nil {
if s.ipv4 != nil || s.ipv6 != nil {
return nil, 0, ErrBindAlreadyOpen
}
@@ -95,104 +161,121 @@ func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
// If uport is 0, we can retry on failure.
again:
port := int(uport)
var ipv4, ipv6 *net.UDPConn
var v4conn, v6conn *net.UDPConn
ipv4, port, err = listenNet("udp4", port)
v4conn, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
return nil, 0, err
}
// Listen on the same port as we're using for ipv4.
ipv6, port, err = listenNet("udp6", port)
v6conn, port, err = listenNet("udp6", port)
if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
ipv4.Close()
v4conn.Close()
tries++
goto again
}
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
ipv4.Close()
v4conn.Close()
return nil, 0, err
}
var fns []ReceiveFunc
if ipv4 != nil {
fns = append(fns, bind.makeReceiveIPv4(ipv4))
bind.ipv4 = ipv4
if v4conn != nil {
fns = append(fns, s.receiveIPv4)
s.ipv4 = v4conn
}
if ipv6 != nil {
fns = append(fns, bind.makeReceiveIPv6(ipv6))
bind.ipv6 = ipv6
if v6conn != nil {
fns = append(fns, s.receiveIPv6)
s.ipv6 = v6conn
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
s.ipv4PC = ipv4.NewPacketConn(s.ipv4)
s.ipv6PC = ipv6.NewPacketConn(s.ipv6)
return fns, uint16(port), nil
}
func (bind *StdNetBind) BatchSize() int {
return 1
func (s *StdNetBind) receiveIPv4(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
defer s.ipv4MsgsPool.Put(msgs)
for i := range buffs {
(*msgs)[i].Buffers[0] = buffs[i]
}
numMsgs, err := s.ipv4PC.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB, ep)
eps[i] = ep
}
return numMsgs, nil
}
func (bind *StdNetBind) Close() error {
bind.mu.Lock()
defer bind.mu.Unlock()
func (s *StdNetBind) receiveIPv6(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
defer s.ipv6MsgsPool.Put(msgs)
for i := range buffs {
(*msgs)[i].Buffers[0] = buffs[i]
}
numMsgs, err := s.ipv6PC.ReadBatch(*msgs, 0)
if err != nil {
return 0, err
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
ep := asEndpoint(addrPort)
getSrcFromControl(msg.OOB, ep)
eps[i] = ep
}
return numMsgs, nil
}
func (s *StdNetBind) BatchSize() int {
return s.batchSize
}
func (s *StdNetBind) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
var err1, err2 error
if bind.ipv4 != nil {
err1 = bind.ipv4.Close()
bind.ipv4 = nil
if s.ipv4 != nil {
err1 = s.ipv4.Close()
s.ipv4 = nil
}
if bind.ipv6 != nil {
err2 = bind.ipv6.Close()
bind.ipv6 = nil
if s.ipv6 != nil {
err2 = s.ipv6.Close()
s.ipv6 = nil
}
bind.blackhole4 = false
bind.blackhole6 = false
s.blackhole4 = false
s.blackhole6 = false
if err1 != nil {
return err1
}
return err2
}
func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
if err == nil {
sizes[0] = size
eps[0] = asEndpoint(endpoint)
return 1, nil
}
return 0, err
func (s *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
s.mu.Lock()
blackhole := s.blackhole4
conn := s.ipv4
is6 := false
if endpoint.DstIP().Is6() {
blackhole = s.blackhole6
conn = s.ipv6
is6 = true
}
}
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
return func(buffs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
size, endpoint, err := conn.ReadFromUDPAddrPort(buffs[0])
if err == nil {
sizes[0] = size
eps[0] = asEndpoint(endpoint)
return 1, nil
}
return 0, err
}
}
func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
var err error
nend, ok := endpoint.(StdNetEndpoint)
if !ok {
return ErrWrongEndpointType
}
addrPort := netip.AddrPort(nend)
bind.mu.Lock()
blackhole := bind.blackhole4
conn := bind.ipv4
if addrPort.Addr().Is6() {
blackhole = bind.blackhole6
conn = bind.ipv6
}
bind.mu.Unlock()
s.mu.Unlock()
if blackhole {
return nil
@@ -200,13 +283,69 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
if conn == nil {
return syscall.EAFNOSUPPORT
}
for _, buff := range buffs {
_, err = conn.WriteToUDPAddrPort(buff, addrPort)
if err != nil {
return err
}
if is6 {
return s.send6(s.ipv6PC, endpoint, buffs)
} else {
return s.send4(s.ipv4PC, endpoint, buffs)
}
return nil
}
func (s *StdNetBind) send4(conn *ipv4.PacketConn, ep Endpoint, buffs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as4 := ep.DstIP().As4()
copy(ua.IP, as4[:])
ua.IP = ua.IP[:4]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv4MsgsPool.Get().(*[]ipv4.Message)
for i, buff := range buffs {
(*msgs)[i].Buffers[0] = buff
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
var (
n int
err error
start int
)
for {
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
if err != nil || n == len((*msgs)[start:len(buffs)]) {
break
}
start += n
}
s.udpAddrPool.Put(ua)
s.ipv4MsgsPool.Put(msgs)
return err
}
func (s *StdNetBind) send6(conn *ipv6.PacketConn, ep Endpoint, buffs [][]byte) error {
ua := s.udpAddrPool.Get().(*net.UDPAddr)
as16 := ep.DstIP().As16()
copy(ua.IP, as16[:])
ua.IP = ua.IP[:16]
ua.Port = int(ep.(*StdNetEndpoint).Port())
msgs := s.ipv6MsgsPool.Get().(*[]ipv6.Message)
for i, buff := range buffs {
(*msgs)[i].Buffers[0] = buff
(*msgs)[i].Addr = ua
setSrcControl(&(*msgs)[i].OOB, ep.(*StdNetEndpoint))
}
var (
n int
err error
start int
)
for {
n, err = conn.WriteBatch((*msgs)[start:len(buffs)], 0)
if err != nil || n == len((*msgs)[start:len(buffs)]) {
break
}
start += n
}
s.udpAddrPool.Put(ua)
s.ipv6MsgsPool.Put(msgs)
return err
}
// endpointPool contains a re-usable set of mapping from netip.AddrPort to Endpoint.
@@ -214,17 +353,17 @@ func (bind *StdNetBind) Send(buffs [][]byte, endpoint Endpoint) error {
// but Endpoints are immutable, so we can re-use them.
var endpointPool = sync.Pool{
New: func() any {
return make(map[netip.AddrPort]Endpoint)
return make(map[netip.AddrPort]*StdNetEndpoint)
},
}
// asEndpoint returns an Endpoint containing ap.
func asEndpoint(ap netip.AddrPort) Endpoint {
m := endpointPool.Get().(map[netip.AddrPort]Endpoint)
func asEndpoint(ap netip.AddrPort) *StdNetEndpoint {
m := endpointPool.Get().(map[netip.AddrPort]*StdNetEndpoint)
defer endpointPool.Put(m)
e, ok := m[ap]
if !ok {
e = Endpoint(StdNetEndpoint(ap))
e = &StdNetEndpoint{AddrPort: ap}
m[ap] = e
}
return e