conn, device, tun: implement vectorized I/O on Linux

Implement TCP offloading via TSO and GRO for the Linux tun.Device, which
is made possible by virtio extensions in the kernel's TUN driver.

Delete conn.LinuxSocketEndpoint in favor of a collapsed conn.StdNetBind.
conn.StdNetBind makes use of recvmmsg() and sendmmsg() on Linux. All
platforms now fall under conn.StdNetBind, except for Windows, which
remains in conn.WinRingBind, which still needs to be adjusted to handle
multiple packets.

Also refactor sticky sockets support to eventually be applicable on
platforms other than just Linux. However, Linux remains the sole platform
that fully implements it for now.

Co-authored-by: James Tucker <james@tailscale.com>
Signed-off-by: James Tucker <james@tailscale.com>
Signed-off-by: Jordan Whited <jordan@tailscale.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jordan Whited
2023-03-02 15:08:28 -08:00
committed by Jason A. Donenfeld
parent 3bb8fec7e4
commit 9e2f386022
24 changed files with 1877 additions and 794 deletions

View File

@@ -17,9 +17,8 @@ import (
"time"
"unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix"
"golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/rwcancel"
)
@@ -33,17 +32,25 @@ type NativeTun struct {
index int32 // if index
errors chan error // async error handling
events chan Event // device related events
nopi bool // the device was passed IFF_NO_PI
netlinkSock int
netlinkCancel *rwcancel.RWCancel
hackListenerClosed sync.Mutex
statusListenersShutdown chan struct{}
batchSize int
vnetHdr bool
closeOnce sync.Once
nameOnce sync.Once // guards calling initNameCache, which sets following fields
nameCache string // name of interface
nameErr error
readOpMu sync.Mutex // readOpMu guards readBuff
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
toWrite []int
tcp4GROTable, tcp6GROTable *tcpGROTable
}
func (tun *NativeTun) File() *os.File {
@@ -323,60 +330,142 @@ func (tun *NativeTun) nameSlow() (string, error) {
return unix.ByteSliceToString(ifr[:]), nil
}
func (tun *NativeTun) Write(buffs [][]byte, offset int) (n int, err error) {
var buf []byte
if tun.nopi {
buf = buffs[0][offset:]
func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
tun.writeOpMu.Lock()
defer func() {
tun.tcp4GROTable.reset()
tun.tcp6GROTable.reset()
tun.writeOpMu.Unlock()
}()
var (
errs []error
total int
)
tun.toWrite = tun.toWrite[:0]
if tun.vnetHdr {
err := handleGRO(buffs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
if err != nil {
return 0, err
}
offset -= virtioNetHdrLen
} else {
// reserve space for header
buf = buffs[0][offset-4:]
// add packet information header
buf[0] = 0x00
buf[1] = 0x00
if buf[4]>>4 == ipv6.Version {
buf[2] = 0x86
buf[3] = 0xdd
} else {
buf[2] = 0x08
buf[3] = 0x00
for i := range buffs {
tun.toWrite = append(tun.toWrite, i)
}
}
_, err = tun.tunFile.Write(buf)
if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed
} else if err == nil {
n = 1
for _, buffsI := range tun.toWrite {
n, err := tun.tunFile.Write(buffs[buffsI][offset:])
if errors.Is(err, syscall.EBADFD) {
return total, os.ErrClosed
}
if err != nil {
errs = append(errs, err)
} else {
total += n
}
}
return n, err
return total, ErrorBatch(errs)
}
func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) {
select {
case err = <-tun.errors:
default:
if tun.nopi {
sizes[0], err = tun.tunFile.Read(buffs[0][offset:])
if err == nil {
n = 1
}
} else {
buff := buffs[0][offset-4:]
sizes[0], err = tun.tunFile.Read(buff[:])
if errors.Is(err, syscall.EBADFD) {
err = os.ErrClosed
} else if err == nil {
n = 1
}
if sizes[0] < 4 {
sizes[0] = 0
} else {
sizes[0] -= 4
// handleVirtioRead splits in into buffs, leaving offset bytes at the front of
// each buffer. It mutates sizes to reflect the size of each element of buffs,
// and returns the number of packets read.
//
// in must begin with a virtioNetHdr, which is decoded and stripped before the
// payload is inspected. A payload whose gsoType is VIRTIO_NET_HDR_GSO_NONE is
// treated as a single packet and copied into buffs[0]. A TCPv4 or TCPv6 GSO
// payload is validated and then handed to tcpTSO. Any other GSO type, or a
// payload that fails validation, produces an error and zero packets.
func handleVirtioRead(in []byte, buffs [][]byte, sizes []int, offset int) (int, error) {
	var hdr virtioNetHdr
	err := hdr.decode(in)
	if err != nil {
		return 0, err
	}
	in = in[virtioNetHdrLen:]
	if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
		if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
			// This means CHECKSUM_PARTIAL in skb context. We are responsible
			// for computing the checksum starting at hdr.csumStart and placing
			// at hdr.csumOffset.
			err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
			if err != nil {
				return 0, err
			}
		}
		dst := buffs[0][offset:]
		if len(in) > len(dst) {
			return 0, fmt.Errorf("read len %d overflows buffs element len %d", len(in), len(dst))
		}
		sizes[0] = copy(dst, in)
		return 1, nil
	}
	if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
		return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
	}
	// The IP version is taken from the first payload byte, so a read that
	// carried nothing beyond the virtio header must be rejected here instead
	// of panicking on in[0].
	if len(in) == 0 {
		return 0, errors.New("packet is too short")
	}
	ipVersion := in[0] >> 4
	switch ipVersion {
	case 4:
		if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
			return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
		}
	case 6:
		if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
			return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
		}
	default:
		return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
	}
	// csumStart and csumOffset are uint16 header fields. Offsets built from
	// them are widened to int before adding so that an out-of-range header
	// value cannot wrap around and slip past the bounds checks below.
	tcpHLenAt := int(hdr.csumStart) + 12
	if len(in) <= tcpHLenAt {
		return 0, errors.New("packet is too short")
	}
	// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
	// of the entire first packet when the kernel is handling it as part of a
	// FORWARD path. Instead, parse the TCP header length and add it onto
	// csumStart, which is synonymous for IP header length.
	tcpHLen := uint16((in[tcpHLenAt] >> 4) * 4)
	if tcpHLen < 20 || tcpHLen > 60 {
		// A TCP header must be between 20 and 60 bytes in length.
		return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
	}
	hdr.hdrLen = hdr.csumStart + tcpHLen
	if len(in) < int(hdr.hdrLen) {
		return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
	}
	// hdrLen can only be smaller than csumStart if the uint16 addition above
	// wrapped; treat that as a malformed header.
	if hdr.hdrLen < hdr.csumStart {
		return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
	}
	cSumAt := int(hdr.csumStart) + int(hdr.csumOffset)
	if cSumAt+1 >= len(in) {
		return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
	}
	return tcpTSO(in, hdr, buffs, sizes, offset)
}
// Read reads from the TUN file into buffs, leaving offset bytes at the front
// of each buffer. It stores each packet's length in sizes and returns the
// number of packets read.
//
// An error already queued on tun.errors is returned before any read is
// attempted. When tun.vnetHdr is set, the read lands in tun.readBuff and is
// split into buffs by handleVirtioRead; otherwise a single packet is read
// straight into buffs[0]. A read failing with syscall.EBADFD is reported as
// os.ErrClosed.
func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
	tun.readOpMu.Lock()
	defer tun.readOpMu.Unlock()
	// Non-blocking check for an asynchronously reported error.
	select {
	case err := <-tun.errors:
		return 0, err
	default:
	}
	readInto := buffs[0][offset:]
	if tun.vnetHdr {
		// Every read is prefixed by a virtioNetHdr in this mode, so read into
		// the scratch buffer sized for header plus payload.
		readInto = tun.readBuff[:]
	}
	n, err := tun.tunFile.Read(readInto)
	if errors.Is(err, syscall.EBADFD) {
		err = os.ErrClosed
	}
	if err != nil {
		return 0, err
	}
	if tun.vnetHdr {
		return handleVirtioRead(readInto[:n], buffs, sizes, offset)
	}
	sizes[0] = n
	return 1, nil
}
func (tun *NativeTun) Events() <-chan Event {
@@ -403,9 +492,49 @@ func (tun *NativeTun) Close() error {
}
// BatchSize returns the batch size stored in tun.batchSize. initFromFlags
// sets that field to conn.DefaultBatchSize when the device has IFF_VNET_HDR
// set, and to 1 otherwise.
func (tun *NativeTun) BatchSize() int {
	return tun.batchSize
}
const (
	// tunOffloads is the set of offload flags passed to the TUNSETOFFLOAD
	// ioctl by initFromFlags once the device is found to have IFF_VNET_HDR
	// set.
	//
	// TODO: support TSO with ECN bits
	tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
)
// initFromFlags queries the flags of the TUN interface called name through
// the TUNGETIFF ioctl and configures tun from them.
//
// If IFF_VNET_HDR is set, the offloads in tunOffloads are requested with
// TUNSETOFFLOAD; on success tun.vnetHdr is set and tun.batchSize becomes
// conn.DefaultBatchSize. If the flag is absent, tun.batchSize is set to 1.
// An error from obtaining or using the raw connection takes precedence over
// an error from the ioctls themselves.
func (tun *NativeTun) initFromFlags(name string) error {
	rawConn, err := tun.tunFile.SyscallConn()
	if err != nil {
		return err
	}
	// ioctlErr carries a failure out of the Control callback, which has no
	// return value of its own.
	var ioctlErr error
	configure := func(fd uintptr) {
		req, e := unix.NewIfreq(name)
		if e != nil {
			ioctlErr = e
			return
		}
		if e = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, req); e != nil {
			ioctlErr = e
			return
		}
		if req.Uint16()&unix.IFF_VNET_HDR == 0 {
			tun.batchSize = 1
			return
		}
		if e = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads); e != nil {
			ioctlErr = e
			return
		}
		tun.vnetHdr = true
		tun.batchSize = conn.DefaultBatchSize
	}
	if ctlErr := rawConn.Control(configure); ctlErr != nil {
		return ctlErr
	}
	return ioctlErr
}
// CreateTUN creates a Device with the provided name and MTU.
func CreateTUN(name string, mtu int) (Device, error) {
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
if err != nil {
@@ -415,25 +544,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
return nil, err
}
var ifr [ifReqSize]byte
var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
nameBytes := []byte(name)
if len(nameBytes) >= unix.IFNAMSIZ {
unix.Close(nfd)
return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
ifr, err := unix.NewIfreq(name)
if err != nil {
return nil, err
}
copy(ifr[:], nameBytes)
*(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(nfd),
uintptr(unix.TUNSETIFF),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 {
unix.Close(nfd)
return nil, errno
// IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
// where a null write will return EINVAL indicating the TUN is up.
ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
if err != nil {
return nil, err
}
err = unix.SetNonblock(nfd, true)
@@ -448,13 +568,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
return CreateTUNFromFile(fd, mtu)
}
// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 5),
errors: make(chan error, 5),
statusListenersShutdown: make(chan struct{}),
nopi: false,
tcp4GROTable: newTCPGROTable(),
tcp6GROTable: newTCPGROTable(),
toWrite: make([]int, 0, conn.DefaultBatchSize),
}
name, err := tun.Name()
@@ -462,8 +585,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return nil, err
}
// start event listener
err = tun.initFromFlags(name)
if err != nil {
return nil, err
}
// start event listener
tun.index, err = getIFIndex(name)
if err != nil {
return nil, err
@@ -492,6 +619,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
return tun, nil
}
// CreateUnmonitoredTUNFromFD creates a Device from the provided file
// descriptor.
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
err := unix.SetNonblock(fd, true)
if err != nil {
@@ -499,14 +628,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
}
file := os.NewFile(uintptr(fd), "/dev/tun")
tun := &NativeTun{
tunFile: file,
events: make(chan Event, 5),
errors: make(chan error, 5),
nopi: true,
tunFile: file,
events: make(chan Event, 5),
errors: make(chan error, 5),
tcp4GROTable: newTCPGROTable(),
tcp6GROTable: newTCPGROTable(),
toWrite: make([]int, 0, conn.DefaultBatchSize),
}
name, err := tun.Name()
if err != nil {
return nil, "", err
}
return tun, name, nil
err = tun.initFromFlags(name)
if err != nil {
return nil, "", err
}
return tun, name, err
}