conn, device, tun: implement vectorized I/O on Linux
Implement TCP offloading via TSO and GRO for the Linux tun.Device, which is made possible by virtio extensions in the kernel's TUN driver. Delete conn.LinuxSocketEndpoint in favor of a collapsed conn.StdNetBind. conn.StdNetBind makes use of recvmmsg() and sendmmsg() on Linux. All platforms now fall under conn.StdNetBind, except for Windows, which remains in conn.WinRingBind, which still needs to be adjusted to handle multiple packets. Also refactor sticky sockets support to eventually be applicable on platforms other than just Linux. However Linux remains the sole platform that fully implements it for now. Co-authored-by: James Tucker <james@tailscale.com> Signed-off-by: James Tucker <james@tailscale.com> Signed-off-by: Jordan Whited <jordan@tailscale.com> Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
committed by
Jason A. Donenfeld
parent
3bb8fec7e4
commit
9e2f386022
283
tun/tun_linux.go
283
tun/tun_linux.go
@@ -17,9 +17,8 @@ import (
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/net/ipv6"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
)
|
||||
|
||||
@@ -33,17 +32,25 @@ type NativeTun struct {
|
||||
index int32 // if index
|
||||
errors chan error // async error handling
|
||||
events chan Event // device related events
|
||||
nopi bool // the device was passed IFF_NO_PI
|
||||
netlinkSock int
|
||||
netlinkCancel *rwcancel.RWCancel
|
||||
hackListenerClosed sync.Mutex
|
||||
statusListenersShutdown chan struct{}
|
||||
batchSize int
|
||||
vnetHdr bool
|
||||
|
||||
closeOnce sync.Once
|
||||
|
||||
nameOnce sync.Once // guards calling initNameCache, which sets following fields
|
||||
nameCache string // name of interface
|
||||
nameErr error
|
||||
|
||||
readOpMu sync.Mutex // readOpMu guards readBuff
|
||||
readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr
|
||||
|
||||
writeOpMu sync.Mutex // writeOpMu guards toWrite, tcp4GROTable, tcp6GROTable
|
||||
toWrite []int
|
||||
tcp4GROTable, tcp6GROTable *tcpGROTable
|
||||
}
|
||||
|
||||
func (tun *NativeTun) File() *os.File {
|
||||
@@ -323,60 +330,142 @@ func (tun *NativeTun) nameSlow() (string, error) {
|
||||
return unix.ByteSliceToString(ifr[:]), nil
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Write(buffs [][]byte, offset int) (n int, err error) {
|
||||
var buf []byte
|
||||
if tun.nopi {
|
||||
buf = buffs[0][offset:]
|
||||
func (tun *NativeTun) Write(buffs [][]byte, offset int) (int, error) {
|
||||
tun.writeOpMu.Lock()
|
||||
defer func() {
|
||||
tun.tcp4GROTable.reset()
|
||||
tun.tcp6GROTable.reset()
|
||||
tun.writeOpMu.Unlock()
|
||||
}()
|
||||
var (
|
||||
errs []error
|
||||
total int
|
||||
)
|
||||
tun.toWrite = tun.toWrite[:0]
|
||||
if tun.vnetHdr {
|
||||
err := handleGRO(buffs, offset, tun.tcp4GROTable, tun.tcp6GROTable, &tun.toWrite)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
offset -= virtioNetHdrLen
|
||||
} else {
|
||||
// reserve space for header
|
||||
buf = buffs[0][offset-4:]
|
||||
|
||||
// add packet information header
|
||||
buf[0] = 0x00
|
||||
buf[1] = 0x00
|
||||
if buf[4]>>4 == ipv6.Version {
|
||||
buf[2] = 0x86
|
||||
buf[3] = 0xdd
|
||||
} else {
|
||||
buf[2] = 0x08
|
||||
buf[3] = 0x00
|
||||
for i := range buffs {
|
||||
tun.toWrite = append(tun.toWrite, i)
|
||||
}
|
||||
}
|
||||
|
||||
_, err = tun.tunFile.Write(buf)
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
err = os.ErrClosed
|
||||
} else if err == nil {
|
||||
n = 1
|
||||
for _, buffsI := range tun.toWrite {
|
||||
n, err := tun.tunFile.Write(buffs[buffsI][offset:])
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
return total, os.ErrClosed
|
||||
}
|
||||
if err != nil {
|
||||
errs = append(errs, err)
|
||||
} else {
|
||||
total += n
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
return total, ErrorBatch(errs)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||
select {
|
||||
case err = <-tun.errors:
|
||||
default:
|
||||
if tun.nopi {
|
||||
sizes[0], err = tun.tunFile.Read(buffs[0][offset:])
|
||||
if err == nil {
|
||||
n = 1
|
||||
}
|
||||
} else {
|
||||
buff := buffs[0][offset-4:]
|
||||
sizes[0], err = tun.tunFile.Read(buff[:])
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
err = os.ErrClosed
|
||||
} else if err == nil {
|
||||
n = 1
|
||||
}
|
||||
if sizes[0] < 4 {
|
||||
sizes[0] = 0
|
||||
} else {
|
||||
sizes[0] -= 4
|
||||
// handleVirtioRead splits in into buffs, leaving offset bytes at the front of
|
||||
// each buffer. It mutates sizes to reflect the size of each element of buffs,
|
||||
// and returns the number of packets read.
|
||||
func handleVirtioRead(in []byte, buffs [][]byte, sizes []int, offset int) (int, error) {
|
||||
var hdr virtioNetHdr
|
||||
err := hdr.decode(in)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
in = in[virtioNetHdrLen:]
|
||||
if hdr.gsoType == unix.VIRTIO_NET_HDR_GSO_NONE {
|
||||
if hdr.flags&unix.VIRTIO_NET_HDR_F_NEEDS_CSUM != 0 {
|
||||
// This means CHECKSUM_PARTIAL in skb context. We are responsible
|
||||
// for computing the checksum starting at hdr.csumStart and placing
|
||||
// at hdr.csumOffset.
|
||||
err = gsoNoneChecksum(in, hdr.csumStart, hdr.csumOffset)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
if len(in) > len(buffs[0][offset:]) {
|
||||
return 0, fmt.Errorf("read len %d overflows buffs element len %d", len(in), len(buffs[0][offset:]))
|
||||
}
|
||||
n := copy(buffs[0][offset:], in)
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 && hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
||||
return 0, fmt.Errorf("unsupported virtio GSO type: %d", hdr.gsoType)
|
||||
}
|
||||
|
||||
ipVersion := in[0] >> 4
|
||||
switch ipVersion {
|
||||
case 4:
|
||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV4 {
|
||||
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
||||
}
|
||||
case 6:
|
||||
if hdr.gsoType != unix.VIRTIO_NET_HDR_GSO_TCPV6 {
|
||||
return 0, fmt.Errorf("ip header version: %d, GSO type: %d", ipVersion, hdr.gsoType)
|
||||
}
|
||||
default:
|
||||
return 0, fmt.Errorf("invalid ip header version: %d", ipVersion)
|
||||
}
|
||||
|
||||
if len(in) <= int(hdr.csumStart+12) {
|
||||
return 0, errors.New("packet is too short")
|
||||
}
|
||||
// Don't trust hdr.hdrLen from the kernel as it can be equal to the length
|
||||
// of the entire first packet when the kernel is handling it as part of a
|
||||
// FORWARD path. Instead, parse the TCP header length and add it onto
|
||||
// csumStart, which is synonymous for IP header length.
|
||||
tcpHLen := uint16(in[hdr.csumStart+12] >> 4 * 4)
|
||||
if tcpHLen < 20 || tcpHLen > 60 {
|
||||
// A TCP header must be between 20 and 60 bytes in length.
|
||||
return 0, fmt.Errorf("tcp header len is invalid: %d", tcpHLen)
|
||||
}
|
||||
hdr.hdrLen = hdr.csumStart + tcpHLen
|
||||
|
||||
if len(in) < int(hdr.hdrLen) {
|
||||
return 0, fmt.Errorf("length of packet (%d) < virtioNetHdr.hdrLen (%d)", len(in), hdr.hdrLen)
|
||||
}
|
||||
|
||||
if hdr.hdrLen < hdr.csumStart {
|
||||
return 0, fmt.Errorf("virtioNetHdr.hdrLen (%d) < virtioNetHdr.csumStart (%d)", hdr.hdrLen, hdr.csumStart)
|
||||
}
|
||||
cSumAt := int(hdr.csumStart + hdr.csumOffset)
|
||||
if cSumAt+1 >= len(in) {
|
||||
return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in))
|
||||
}
|
||||
|
||||
return tcpTSO(in, hdr, buffs, sizes, offset)
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Read(buffs [][]byte, sizes []int, offset int) (int, error) {
|
||||
tun.readOpMu.Lock()
|
||||
defer tun.readOpMu.Unlock()
|
||||
select {
|
||||
case err := <-tun.errors:
|
||||
return 0, err
|
||||
default:
|
||||
readInto := buffs[0][offset:]
|
||||
if tun.vnetHdr {
|
||||
readInto = tun.readBuff[:]
|
||||
}
|
||||
n, err := tun.tunFile.Read(readInto)
|
||||
if errors.Is(err, syscall.EBADFD) {
|
||||
err = os.ErrClosed
|
||||
}
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
if tun.vnetHdr {
|
||||
return handleVirtioRead(readInto[:n], buffs, sizes, offset)
|
||||
} else {
|
||||
sizes[0] = n
|
||||
return 1, nil
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (tun *NativeTun) Events() <-chan Event {
|
||||
@@ -403,9 +492,49 @@ func (tun *NativeTun) Close() error {
|
||||
}
|
||||
|
||||
func (tun *NativeTun) BatchSize() int {
|
||||
return 1
|
||||
return tun.batchSize
|
||||
}
|
||||
|
||||
const (
|
||||
// TODO: support TSO with ECN bits
|
||||
tunOffloads = unix.TUN_F_CSUM | unix.TUN_F_TSO4 | unix.TUN_F_TSO6
|
||||
)
|
||||
|
||||
func (tun *NativeTun) initFromFlags(name string) error {
|
||||
sc, err := tun.tunFile.SyscallConn()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if e := sc.Control(func(fd uintptr) {
|
||||
var (
|
||||
ifr *unix.Ifreq
|
||||
)
|
||||
ifr, err = unix.NewIfreq(name)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = unix.IoctlIfreq(int(fd), unix.TUNGETIFF, ifr)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
got := ifr.Uint16()
|
||||
if got&unix.IFF_VNET_HDR != 0 {
|
||||
err = unix.IoctlSetInt(int(fd), unix.TUNSETOFFLOAD, tunOffloads)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
tun.vnetHdr = true
|
||||
tun.batchSize = conn.DefaultBatchSize
|
||||
} else {
|
||||
tun.batchSize = 1
|
||||
}
|
||||
}); e != nil {
|
||||
return e
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// CreateTUN creates a Device with the provided name and MTU.
|
||||
func CreateTUN(name string, mtu int) (Device, error) {
|
||||
nfd, err := unix.Open(cloneDevicePath, unix.O_RDWR|unix.O_CLOEXEC, 0)
|
||||
if err != nil {
|
||||
@@ -415,25 +544,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ifr [ifReqSize]byte
|
||||
var flags uint16 = unix.IFF_TUN // | unix.IFF_NO_PI (disabled for TUN status hack)
|
||||
nameBytes := []byte(name)
|
||||
if len(nameBytes) >= unix.IFNAMSIZ {
|
||||
unix.Close(nfd)
|
||||
return nil, fmt.Errorf("interface name too long: %w", unix.ENAMETOOLONG)
|
||||
ifr, err := unix.NewIfreq(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
copy(ifr[:], nameBytes)
|
||||
*(*uint16)(unsafe.Pointer(&ifr[unix.IFNAMSIZ])) = flags
|
||||
|
||||
_, _, errno := unix.Syscall(
|
||||
unix.SYS_IOCTL,
|
||||
uintptr(nfd),
|
||||
uintptr(unix.TUNSETIFF),
|
||||
uintptr(unsafe.Pointer(&ifr[0])),
|
||||
)
|
||||
if errno != 0 {
|
||||
unix.Close(nfd)
|
||||
return nil, errno
|
||||
// IFF_VNET_HDR enables the "tun status hack" via routineHackListener()
|
||||
// where a null write will return EINVAL indicating the TUN is up.
|
||||
ifr.SetUint16(unix.IFF_TUN | unix.IFF_NO_PI | unix.IFF_VNET_HDR)
|
||||
err = unix.IoctlIfreq(nfd, unix.TUNSETIFF, ifr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = unix.SetNonblock(nfd, true)
|
||||
@@ -448,13 +568,16 @@ func CreateTUN(name string, mtu int) (Device, error) {
|
||||
return CreateTUNFromFile(fd, mtu)
|
||||
}
|
||||
|
||||
// CreateTUNFromFile creates a Device from an os.File with the provided MTU.
|
||||
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||
tun := &NativeTun{
|
||||
tunFile: file,
|
||||
events: make(chan Event, 5),
|
||||
errors: make(chan error, 5),
|
||||
statusListenersShutdown: make(chan struct{}),
|
||||
nopi: false,
|
||||
tcp4GROTable: newTCPGROTable(),
|
||||
tcp6GROTable: newTCPGROTable(),
|
||||
toWrite: make([]int, 0, conn.DefaultBatchSize),
|
||||
}
|
||||
|
||||
name, err := tun.Name()
|
||||
@@ -462,8 +585,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// start event listener
|
||||
err = tun.initFromFlags(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// start event listener
|
||||
tun.index, err = getIFIndex(name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -492,6 +619,8 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
|
||||
return tun, nil
|
||||
}
|
||||
|
||||
// CreateUnmonitoredTUNFromFD creates a Device from the provided file
|
||||
// descriptor.
|
||||
func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
|
||||
err := unix.SetNonblock(fd, true)
|
||||
if err != nil {
|
||||
@@ -499,14 +628,20 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) {
|
||||
}
|
||||
file := os.NewFile(uintptr(fd), "/dev/tun")
|
||||
tun := &NativeTun{
|
||||
tunFile: file,
|
||||
events: make(chan Event, 5),
|
||||
errors: make(chan error, 5),
|
||||
nopi: true,
|
||||
tunFile: file,
|
||||
events: make(chan Event, 5),
|
||||
errors: make(chan error, 5),
|
||||
tcp4GROTable: newTCPGROTable(),
|
||||
tcp6GROTable: newTCPGROTable(),
|
||||
toWrite: make([]int, 0, conn.DefaultBatchSize),
|
||||
}
|
||||
name, err := tun.Name()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return tun, name, nil
|
||||
err = tun.initFromFlags(name)
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
return tun, name, err
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user