20 Commits

Author SHA1 Message Date
Jason A. Donenfeld
c9db4b7aaa version: bump snapshot
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-24 13:07:27 -04:00
Jason A. Donenfeld
3625f8d284 tun: freebsd: avoid OOB writes
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 15:10:23 -06:00
Jason A. Donenfeld
0687dc06c8 tun: freebsd: become controlling process when reopening tun FD
When we pass the TUN FD to the child, we have to call TUNSIFPID;
otherwise when we close the device, we get a splat in dmesg.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 15:02:44 -06:00
Jason A. Donenfeld
71aefa374d tun: freebsd: restructure and cleanup
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 14:54:59 -06:00
Jason A. Donenfeld
3d3e30beb8 tun: freebsd: remove horrific hack for getting tunnel name
As of FreeBSD 12.1, there's TUNGIFNAME.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-19 12:03:16 -06:00
Jason A. Donenfeld
b0e5b19969 tun: freebsd: set IFF_MULTICAST for routing daemons
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-18 20:09:04 -06:00
Jason A. Donenfeld
3988821442 main: print kernel warning on OpenBSD and FreeBSD too
More kernels!

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-15 23:35:45 -06:00
Jason A. Donenfeld
c7cd2c9eab device: don't defer unlocking from loop
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-12 16:19:35 -06:00
Jason A. Donenfeld
54dbe2471f conn: reconstruct v4 vs v6 receive function based on symtab
This is kind of gross but it's better than the alternatives.

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-12 15:35:32 -06:00
Kristupas Antanavičius
d2fd0c0cc0 device: allocate new buffer in receive death spiral
Note: this bug is "hidden" by avoiding "death spiral" code path by
6228659 ("device: handle broader range of errors in RoutineReceiveIncoming").

If the code reached "death spiral" mechanism, there would be multiple
double frees happening. This results in a deadlock on iOS, because the
pools are fixed size and goroutine might stop until somebody makes
space in the pool.

This was almost 100% repro on the new ARM Macbooks:

- Build with 'ios' tag for Mac. This will enable bounded pools.
- Somehow call device.IpcSet at least couple of times (update config)
- device.BindUpdate() would be triggered
- RoutineReceiveIncoming would enter "death spiral".
- RoutineReceiveIncoming would stall on double free (pool is already
  full)
- The stuck routine would deadlock 'device.closeBindLocked()' function
  on line 'netc.stopping.Wait()'

Signed-off-by: Kristupas Antanavičius <kristupas.antanavicius@nordsec.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-12 11:14:53 -06:00
Jason A. Donenfeld
5f6bbe4ae8 conn: windows: reset ring to starting position after free
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-09 18:09:41 -06:00
Jason A. Donenfeld
75526d6071 conn: windows: compare head and tail properly
By not comparing these with the modulo, the ring became nearly never
full, resulting in completion queue buffers filling up prematurely.

Reported-by: Joshua Sjoding <joshua.sjoding@scjalliance.com>
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-09 14:26:08 -06:00
Jason A. Donenfeld
fbf97502cf winrio: test that IOCP-based RIO is supported
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-09 14:26:08 -06:00
Josh Bleecher Snyder
10533c3e73 all: make conn.Bind.Open return a slice of receive functions
Instead of hard-coding exactly two sources from which
to receive packets (an IPv4 source and an IPv6 source),
allow the conn.Bind to specify a set of sources.

Beneficial consequences:

* If there's no IPv6 support on a system,
  conn.Bind.Open can choose not to return a receive function for it,
  which is simpler than tracking that state in the bind.
  This simplification removes existing data races from both
  conn.StdNetBind and bindtest.ChannelBind.
* If there are more than two sources on a system,
  the conn.Bind no longer needs to add a separate muxing layer.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-04-02 11:07:08 -06:00
Jason A. Donenfeld
8ed83e0427 conn: winrio: pass key parameter into struct
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
2021-04-02 10:36:41 -06:00
Josh Bleecher Snyder
6228659a91 device: handle broader range of errors in RoutineReceiveIncoming
RoutineReceiveIncoming exits immediately on net.ErrClosed,
but not on other errors. However, for errors that are known
to be permanent, such as syscall.EAFNOSUPPORT,
we may as well exit immediately instead of retrying.

This considerably speeds up the package device tests right now,
because the Bind sometimes (incorrectly) returns syscall.EAFNOSUPPORT
instead of net.ErrClosed.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:41:43 -07:00
Josh Bleecher Snyder
517f0703f5 conn: document retry loop in StdNetBind.Open
It's not obvious on a first read what the loop is doing.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:09:38 -07:00
Josh Bleecher Snyder
204140016a conn: use local ipvN vars in StdNetBind.Open
This makes it clearer that they are fresh on each attempt,
and avoids the bookkeeping required to clearing them on failure.

Also, remove an unnecessary err != nil.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:09:38 -07:00
Josh Bleecher Snyder
822f5a6d70 conn: unify code in StdNetBind.Send
The sending code is identical for ipv4 and ipv6;
select the conn, then use it.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:09:32 -07:00
Josh Bleecher Snyder
02e419ed8a device: rename unsafeCloseBind to closeBindLocked
And document a bit.
This name is more idiomatic.

Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
2021-03-30 12:07:12 -07:00
12 changed files with 345 additions and 428 deletions

View File

@@ -55,10 +55,11 @@ func (endpoint *LinuxSocketEndpoint) dst6() *unix.SockaddrInet6 {
// LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux. // LinuxSocketBind uses sendmsg and recvmsg to implement a full bind with sticky sockets on Linux.
type LinuxSocketBind struct { type LinuxSocketBind struct {
// mu guards sock4 and sock6 and the associated fds.
// As long as someone holds mu (read or write), the associated fds are valid.
mu sync.RWMutex
sock4 int sock4 int
sock6 int sock6 int
lastMark uint32
closing sync.RWMutex
} }
func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} } func NewLinuxSocketBind() Bind { return &LinuxSocketBind{sock4: -1, sock6: -1} }
@@ -102,54 +103,67 @@ func (*LinuxSocketBind) ParseEndpoint(s string) (Endpoint, error) {
return nil, errors.New("invalid IP address") return nil, errors.New("invalid IP address")
} }
func (bind *LinuxSocketBind) Open(port uint16) (uint16, error) { func (bind *LinuxSocketBind) Open(port uint16) ([]ReceiveFunc, uint16, error) {
bind.mu.Lock()
defer bind.mu.Unlock()
var err error var err error
var newPort uint16 var newPort uint16
var tries int var tries int
if bind.sock4 != -1 || bind.sock6 != -1 { if bind.sock4 != -1 || bind.sock6 != -1 {
return 0, ErrBindAlreadyOpen return nil, 0, ErrBindAlreadyOpen
} }
originalPort := port originalPort := port
again: again:
port = originalPort port = originalPort
var sock4, sock6 int
// Attempt ipv6 bind, update port if successful. // Attempt ipv6 bind, update port if successful.
bind.sock6, newPort, err = create6(port) sock6, newPort, err = create6(port)
if err != nil { if err != nil {
if err != syscall.EAFNOSUPPORT { if !errors.Is(err, syscall.EAFNOSUPPORT) {
return 0, err return nil, 0, err
} }
} else { } else {
port = newPort port = newPort
} }
// Attempt ipv4 bind, update port if successful. // Attempt ipv4 bind, update port if successful.
bind.sock4, newPort, err = create4(port) sock4, newPort, err = create4(port)
if err != nil { if err != nil {
if originalPort == 0 && err == syscall.EADDRINUSE && tries < 100 { if originalPort == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
unix.Close(bind.sock6) unix.Close(sock6)
tries++ tries++
goto again goto again
} }
if err != syscall.EAFNOSUPPORT { if !errors.Is(err, syscall.EAFNOSUPPORT) {
unix.Close(bind.sock6) unix.Close(sock6)
return 0, err return nil, 0, err
} }
} else { } else {
port = newPort port = newPort
} }
if bind.sock4 == -1 && bind.sock6 == -1 { var fns []ReceiveFunc
return 0, syscall.EAFNOSUPPORT if sock4 != -1 {
fns = append(fns, bind.makeReceiveIPv4(sock4))
bind.sock4 = sock4
} }
return port, nil if sock6 != -1 {
fns = append(fns, bind.makeReceiveIPv6(sock6))
bind.sock6 = sock6
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
return fns, port, nil
} }
func (bind *LinuxSocketBind) SetMark(value uint32) error { func (bind *LinuxSocketBind) SetMark(value uint32) error {
bind.closing.RLock() bind.mu.RLock()
defer bind.closing.RUnlock() defer bind.mu.RUnlock()
if bind.sock6 != -1 { if bind.sock6 != -1 {
err := unix.SetsockoptInt( err := unix.SetsockoptInt(
@@ -177,21 +191,24 @@ func (bind *LinuxSocketBind) SetMark(value uint32) error {
} }
} }
bind.lastMark = value
return nil return nil
} }
func (bind *LinuxSocketBind) Close() error { func (bind *LinuxSocketBind) Close() error {
var err1, err2 error // Take a readlock to shut down the sockets...
bind.closing.RLock() bind.mu.RLock()
if bind.sock6 != -1 { if bind.sock6 != -1 {
unix.Shutdown(bind.sock6, unix.SHUT_RDWR) unix.Shutdown(bind.sock6, unix.SHUT_RDWR)
} }
if bind.sock4 != -1 { if bind.sock4 != -1 {
unix.Shutdown(bind.sock4, unix.SHUT_RDWR) unix.Shutdown(bind.sock4, unix.SHUT_RDWR)
} }
bind.closing.RUnlock() bind.mu.RUnlock()
bind.closing.Lock() // ...and a write lock to close the fd.
// This ensures that no one else is using the fd.
bind.mu.Lock()
defer bind.mu.Unlock()
var err1, err2 error
if bind.sock6 != -1 { if bind.sock6 != -1 {
err1 = unix.Close(bind.sock6) err1 = unix.Close(bind.sock6)
bind.sock6 = -1 bind.sock6 = -1
@@ -200,7 +217,6 @@ func (bind *LinuxSocketBind) Close() error {
err2 = unix.Close(bind.sock4) err2 = unix.Close(bind.sock4)
bind.sock4 = -1 bind.sock4 = -1
} }
bind.closing.Unlock()
if err1 != nil { if err1 != nil {
return err1 return err1
@@ -208,46 +224,29 @@ func (bind *LinuxSocketBind) Close() error {
return err2 return err2
} }
func (bind *LinuxSocketBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) { func (*LinuxSocketBind) makeReceiveIPv6(sock int) ReceiveFunc {
bind.closing.RLock() return func(buff []byte) (int, Endpoint, error) {
defer bind.closing.RUnlock()
var end LinuxSocketEndpoint var end LinuxSocketEndpoint
if bind.sock6 == -1 { n, err := receive6(sock, buff, &end)
return 0, nil, net.ErrClosed
}
n, err := receive6(
bind.sock6,
buff,
&end,
)
return n, &end, err return n, &end, err
} }
func (bind *LinuxSocketBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) {
bind.closing.RLock()
defer bind.closing.RUnlock()
var end LinuxSocketEndpoint
if bind.sock4 == -1 {
return 0, nil, net.ErrClosed
} }
n, err := receive4(
bind.sock4, func (*LinuxSocketBind) makeReceiveIPv4(sock int) ReceiveFunc {
buff, return func(buff []byte) (int, Endpoint, error) {
&end, var end LinuxSocketEndpoint
) n, err := receive4(sock, buff, &end)
return n, &end, err return n, &end, err
} }
}
func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error { func (bind *LinuxSocketBind) Send(buff []byte, end Endpoint) error {
bind.closing.RLock()
defer bind.closing.RUnlock()
nend, ok := end.(*LinuxSocketEndpoint) nend, ok := end.(*LinuxSocketEndpoint)
if !ok { if !ok {
return ErrWrongEndpointType return ErrWrongEndpointType
} }
bind.mu.RLock()
defer bind.mu.RUnlock()
if !nend.isV6 { if !nend.isV6 {
if bind.sock4 == -1 { if bind.sock4 == -1 {
return net.ErrClosed return net.ErrClosed

View File

@@ -8,6 +8,7 @@ package conn
import ( import (
"errors" "errors"
"net" "net"
"sync"
"syscall" "syscall"
) )
@@ -16,6 +17,7 @@ import (
// It uses the Go's net package to implement networking. // It uses the Go's net package to implement networking.
// See LinuxSocketBind for a proper implementation on the Linux platform. // See LinuxSocketBind for a proper implementation on the Linux platform.
type StdNetBind struct { type StdNetBind struct {
mu sync.Mutex // protects following fields
ipv4 *net.UDPConn ipv4 *net.UDPConn
ipv6 *net.UDPConn ipv6 *net.UDPConn
blackhole4 bool blackhole4 bool
@@ -81,44 +83,58 @@ func listenNet(network string, port int) (*net.UDPConn, int, error) {
return conn, uaddr.Port, nil return conn, uaddr.Port, nil
} }
func (bind *StdNetBind) Open(uport uint16) (uint16, error) { func (bind *StdNetBind) Open(uport uint16) ([]ReceiveFunc, uint16, error) {
bind.mu.Lock()
defer bind.mu.Unlock()
var err error var err error
var tries int var tries int
if bind.ipv4 != nil || bind.ipv6 != nil { if bind.ipv4 != nil || bind.ipv6 != nil {
return 0, ErrBindAlreadyOpen return nil, 0, ErrBindAlreadyOpen
} }
// Attempt to open ipv4 and ipv6 listeners on the same port.
// If uport is 0, we can retry on failure.
again: again:
port := int(uport) port := int(uport)
var ipv4, ipv6 *net.UDPConn
bind.ipv4, port, err = listenNet("udp4", port) ipv4, port, err = listenNet("udp4", port)
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
bind.ipv4 = nil return nil, 0, err
return 0, err
} }
bind.ipv6, port, err = listenNet("udp6", port) // Listen on the same port as we're using for ipv4.
if uport == 0 && err != nil && errors.Is(err, syscall.EADDRINUSE) && tries < 100 { ipv6, port, err = listenNet("udp6", port)
bind.ipv4.Close() if uport == 0 && errors.Is(err, syscall.EADDRINUSE) && tries < 100 {
bind.ipv4 = nil ipv4.Close()
bind.ipv6 = nil
tries++ tries++
goto again goto again
} }
if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) { if err != nil && !errors.Is(err, syscall.EAFNOSUPPORT) {
bind.ipv4.Close() ipv4.Close()
bind.ipv4 = nil return nil, 0, err
bind.ipv6 = nil
return 0, err
} }
if bind.ipv4 == nil && bind.ipv6 == nil { var fns []ReceiveFunc
return 0, syscall.EAFNOSUPPORT if ipv4 != nil {
fns = append(fns, bind.makeReceiveIPv4(ipv4))
bind.ipv4 = ipv4
} }
return uint16(port), nil if ipv6 != nil {
fns = append(fns, bind.makeReceiveIPv6(ipv6))
bind.ipv6 = ipv6
}
if len(fns) == 0 {
return nil, 0, syscall.EAFNOSUPPORT
}
return fns, uint16(port), nil
} }
func (bind *StdNetBind) Close() error { func (bind *StdNetBind) Close() error {
bind.mu.Lock()
defer bind.mu.Unlock()
var err1, err2 error var err1, err2 error
if bind.ipv4 != nil { if bind.ipv4 != nil {
err1 = bind.ipv4.Close() err1 = bind.ipv4.Close()
@@ -136,24 +152,22 @@ func (bind *StdNetBind) Close() error {
return err2 return err2
} }
func (bind *StdNetBind) ReceiveIPv4(buff []byte) (int, Endpoint, error) { func (*StdNetBind) makeReceiveIPv4(conn *net.UDPConn) ReceiveFunc {
if bind.ipv4 == nil { return func(buff []byte) (int, Endpoint, error) {
return 0, nil, syscall.EAFNOSUPPORT n, endpoint, err := conn.ReadFromUDP(buff)
}
n, endpoint, err := bind.ipv4.ReadFromUDP(buff)
if endpoint != nil { if endpoint != nil {
endpoint.IP = endpoint.IP.To4() endpoint.IP = endpoint.IP.To4()
} }
return n, (*StdNetEndpoint)(endpoint), err return n, (*StdNetEndpoint)(endpoint), err
} }
func (bind *StdNetBind) ReceiveIPv6(buff []byte) (int, Endpoint, error) {
if bind.ipv6 == nil {
return 0, nil, syscall.EAFNOSUPPORT
} }
n, endpoint, err := bind.ipv6.ReadFromUDP(buff)
func (*StdNetBind) makeReceiveIPv6(conn *net.UDPConn) ReceiveFunc {
return func(buff []byte) (int, Endpoint, error) {
n, endpoint, err := conn.ReadFromUDP(buff)
return n, (*StdNetEndpoint)(endpoint), err return n, (*StdNetEndpoint)(endpoint), err
} }
}
func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error { func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
var err error var err error
@@ -161,22 +175,22 @@ func (bind *StdNetBind) Send(buff []byte, endpoint Endpoint) error {
if !ok { if !ok {
return ErrWrongEndpointType return ErrWrongEndpointType
} }
if nend.IP.To4() != nil {
if bind.ipv4 == nil { bind.mu.Lock()
return syscall.EAFNOSUPPORT blackhole := bind.blackhole4
conn := bind.ipv4
if nend.IP.To4() == nil {
blackhole = bind.blackhole6
conn = bind.ipv6
} }
if bind.blackhole4 { bind.mu.Unlock()
if blackhole {
return nil return nil
} }
_, err = bind.ipv4.WriteToUDP(buff, (*net.UDPAddr)(nend)) if conn == nil {
} else {
if bind.ipv6 == nil {
return syscall.EAFNOSUPPORT return syscall.EAFNOSUPPORT
} }
if bind.blackhole6 { _, err = conn.WriteToUDP(buff, (*net.UDPAddr)(nend))
return nil
}
_, err = bind.ipv6.WriteToUDP(buff, (*net.UDPAddr)(nend))
}
return err return err
} }

View File

@@ -47,7 +47,7 @@ func (rb *ringBuffer) Push() *ringPacket {
} }
ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{})))) ret := (*ringPacket)(unsafe.Pointer(rb.packets + (uintptr(rb.tail%packetsPerRing) * unsafe.Sizeof(ringPacket{}))))
rb.tail += 1 rb.tail += 1
if rb.tail == rb.head { if rb.tail%packetsPerRing == rb.head%packetsPerRing {
rb.isFull = true rb.isFull = true
} }
return ret return ret
@@ -197,6 +197,9 @@ func (ring *ringBuffer) CloseAndZero() {
windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE) windows.VirtualFree(ring.packets, 0, windows.MEM_RELEASE)
ring.packets = 0 ring.packets = 0
} }
ring.head = 0
ring.tail = 0
ring.isFull = false
} }
func (bind *afWinRingBind) CloseAndZero() { func (bind *afWinRingBind) CloseAndZero() {
@@ -266,7 +269,7 @@ func (bind *afWinRingBind) Open(family int32, sa windows.Sockaddr) (windows.Sock
return sa, nil return sa, nil
} }
func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) { func (bind *WinRingBind) Open(port uint16) (recvFns []ReceiveFunc, selectedPort uint16, err error) {
bind.mu.Lock() bind.mu.Lock()
defer bind.mu.Unlock() defer bind.mu.Unlock()
defer func() { defer func() {
@@ -275,30 +278,30 @@ func (bind *WinRingBind) Open(port uint16) (selectedPort uint16, err error) {
} }
}() }()
if atomic.LoadUint32(&bind.isOpen) != 0 { if atomic.LoadUint32(&bind.isOpen) != 0 {
return 0, ErrBindAlreadyOpen return nil, 0, ErrBindAlreadyOpen
} }
var sa windows.Sockaddr var sa windows.Sockaddr
sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)}) sa, err = bind.v4.Open(windows.AF_INET, &windows.SockaddrInet4{Port: int(port)})
if err != nil { if err != nil {
return 0, err return nil, 0, err
} }
sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port}) sa, err = bind.v6.Open(windows.AF_INET6, &windows.SockaddrInet6{Port: sa.(*windows.SockaddrInet4).Port})
if err != nil { if err != nil {
return 0, err return nil, 0, err
} }
selectedPort = uint16(sa.(*windows.SockaddrInet6).Port) selectedPort = uint16(sa.(*windows.SockaddrInet6).Port)
for i := 0; i < packetsPerRing; i++ { for i := 0; i < packetsPerRing; i++ {
err = bind.v4.InsertReceiveRequest() err = bind.v4.InsertReceiveRequest()
if err != nil { if err != nil {
return 0, err return nil, 0, err
} }
err = bind.v6.InsertReceiveRequest() err = bind.v6.InsertReceiveRequest()
if err != nil { if err != nil {
return 0, err return nil, 0, err
} }
} }
atomic.StoreUint32(&bind.isOpen, 1) atomic.StoreUint32(&bind.isOpen, 1)
return return []ReceiveFunc{bind.receiveIPv4, bind.receiveIPv6}, selectedPort, err
} }
func (bind *WinRingBind) Close() error { func (bind *WinRingBind) Close() error {
@@ -395,13 +398,13 @@ func (bind *afWinRingBind) Receive(buf []byte, isOpen *uint32) (int, Endpoint, e
return n, &ep, nil return n, &ep, nil
} }
func (bind *WinRingBind) ReceiveIPv4(buf []byte) (int, Endpoint, error) { func (bind *WinRingBind) receiveIPv4(buf []byte) (int, Endpoint, error) {
bind.mu.RLock() bind.mu.RLock()
defer bind.mu.RUnlock() defer bind.mu.RUnlock()
return bind.v4.Receive(buf, &bind.isOpen) return bind.v4.Receive(buf, &bind.isOpen)
} }
func (bind *WinRingBind) ReceiveIPv6(buf []byte) (int, Endpoint, error) { func (bind *WinRingBind) receiveIPv6(buf []byte) (int, Endpoint, error) {
bind.mu.RLock() bind.mu.RLock()
defer bind.mu.RUnlock() defer bind.mu.RUnlock()
return bind.v6.Receive(buf, &bind.isOpen) return bind.v6.Receive(buf, &bind.isOpen)
@@ -482,6 +485,8 @@ func (bind *WinRingBind) Send(buf []byte, endpoint Endpoint) error {
} }
func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error { func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole bool) error {
bind.mu.Lock()
defer bind.mu.Unlock()
sysconn, err := bind.ipv4.SyscallConn() sysconn, err := bind.ipv4.SyscallConn()
if err != nil { if err != nil {
return err return err
@@ -500,6 +505,8 @@ func (bind *StdNetBind) BindSocketToInterface4(interfaceIndex uint32, blackhole
} }
func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error { func (bind *StdNetBind) BindSocketToInterface6(interfaceIndex uint32, blackhole bool) error {
bind.mu.Lock()
defer bind.mu.Unlock()
sysconn, err := bind.ipv6.SyscallConn() sysconn, err := bind.ipv6.SyscallConn()
if err != nil { if err != nil {
return err return err

View File

@@ -65,12 +65,14 @@ func (c ChannelEndpoint) DstIP() net.IP { return net.IPv4(127, 0, 0, 1) }
func (c ChannelEndpoint) SrcIP() net.IP { return nil } func (c ChannelEndpoint) SrcIP() net.IP { return nil }
func (c *ChannelBind) Open(port uint16) (actualPort uint16, err error) { func (c *ChannelBind) Open(port uint16) (fns []conn.ReceiveFunc, actualPort uint16, err error) {
c.closeSignal = make(chan bool) c.closeSignal = make(chan bool)
fns = append(fns, c.makeReceiveFunc(*c.rx4))
fns = append(fns, c.makeReceiveFunc(*c.rx6))
if rand.Uint32()&1 == 0 { if rand.Uint32()&1 == 0 {
return uint16(c.source4), nil return fns, uint16(c.source4), nil
} else { } else {
return uint16(c.source6), nil return fns, uint16(c.source6), nil
} }
} }
@@ -87,22 +89,15 @@ func (c *ChannelBind) Close() error {
func (c *ChannelBind) SetMark(mark uint32) error { return nil } func (c *ChannelBind) SetMark(mark uint32) error { return nil }
func (c *ChannelBind) ReceiveIPv6(b []byte) (n int, ep conn.Endpoint, err error) { func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
return func(b []byte) (n int, ep conn.Endpoint, err error) {
select { select {
case <-c.closeSignal: case <-c.closeSignal:
return 0, nil, net.ErrClosed return 0, nil, net.ErrClosed
case rx := <-*c.rx6: case rx := <-ch:
return copy(b, rx), c.target6, nil return copy(b, rx), c.target6, nil
} }
} }
func (c *ChannelBind) ReceiveIPv4(b []byte) (n int, ep conn.Endpoint, err error) {
select {
case <-c.closeSignal:
return 0, nil, net.ErrClosed
case rx := <-*c.rx4:
return copy(b, rx), c.target4, nil
}
} }
func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error { func (c *ChannelBind) Send(b []byte, ep conn.Endpoint) error {

View File

@@ -8,10 +8,18 @@ package conn
import ( import (
"errors" "errors"
"fmt"
"net" "net"
"reflect"
"runtime"
"strings" "strings"
) )
// A ReceiveFunc receives a single inbound packet from the network.
// It writes the data into b. n is the length of the packet.
// ep is the remote endpoint.
type ReceiveFunc func(b []byte) (n int, ep Endpoint, err error)
// A Bind listens on a port for both IPv6 and IPv4 UDP traffic. // A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
// //
// A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface, // A Bind interface may also be a PeekLookAtSocketFd or BindSocketToInterface,
@@ -19,23 +27,17 @@ import (
type Bind interface { type Bind interface {
// Open puts the Bind into a listening state on a given port and reports the actual // Open puts the Bind into a listening state on a given port and reports the actual
// port that it bound to. Passing zero results in a random selection. // port that it bound to. Passing zero results in a random selection.
Open(port uint16) (actualPort uint16, err error) // fns is the set of functions that will be called to receive packets.
Open(port uint16) (fns []ReceiveFunc, actualPort uint16, err error)
// Close closes the Bind listener. // Close closes the Bind listener.
// All fns returned by Open must return net.ErrClosed after a call to Close.
Close() error Close() error
// SetMark sets the mark for each packet sent through this Bind. // SetMark sets the mark for each packet sent through this Bind.
// This mark is passed to the kernel as the socket option SO_MARK. // This mark is passed to the kernel as the socket option SO_MARK.
SetMark(mark uint32) error SetMark(mark uint32) error
// ReceiveIPv6 reads an IPv6 UDP packet into b. It reports the number of bytes read,
// n, the packet source address ep, and any error.
ReceiveIPv6(b []byte) (n int, ep Endpoint, err error)
// ReceiveIPv4 reads an IPv4 UDP packet into b. It reports the number of bytes read,
// n, the packet source address ep, and any error.
ReceiveIPv4(b []byte) (n int, ep Endpoint, err error)
// Send writes a packet b to address ep. // Send writes a packet b to address ep.
Send(b []byte, ep Endpoint) error Send(b []byte, ep Endpoint) error
@@ -70,6 +72,54 @@ type Endpoint interface {
SrcIP() net.IP SrcIP() net.IP
} }
var (
ErrBindAlreadyOpen = errors.New("bind is already open")
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
)
func (fn ReceiveFunc) PrettyName() string {
name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
// 0. cheese/taco.beansIPv6.func12.func21218-fm
name = strings.TrimSuffix(name, "-fm")
// 1. cheese/taco.beansIPv6.func12.func21218
if idx := strings.LastIndexByte(name, '/'); idx != -1 {
name = name[idx+1:]
// 2. taco.beansIPv6.func12.func21218
}
for {
var idx int
for idx = len(name) - 1; idx >= 0; idx-- {
if name[idx] < '0' || name[idx] > '9' {
break
}
}
if idx == len(name)-1 {
break
}
const dotFunc = ".func"
if !strings.HasSuffix(name[:idx+1], dotFunc) {
break
}
name = name[:idx+1-len(dotFunc)]
// 3. taco.beansIPv6.func12
// 4. taco.beansIPv6
}
if idx := strings.LastIndexByte(name, '.'); idx != -1 {
name = name[idx+1:]
// 5. beansIPv6
}
if name == "" {
return fmt.Sprintf("%p", fn)
}
if strings.HasSuffix(name, "IPv4") {
return "v4"
}
if strings.HasSuffix(name, "IPv6") {
return "v6"
}
return name
}
func parseEndpoint(s string) (*net.UDPAddr, error) { func parseEndpoint(s string) (*net.UDPAddr, error) {
// ensure that the host is an IP address // ensure that the host is an IP address
@@ -99,8 +149,3 @@ func parseEndpoint(s string) (*net.UDPAddr, error) {
} }
return addr, err return addr, err
} }
var (
ErrBindAlreadyOpen = errors.New("bind is already open")
ErrWrongEndpointType = errors.New("endpoint type does not correspond with bind type")
)

View File

@@ -118,9 +118,17 @@ func Initialize() bool {
if err != nil { if err != nil {
return return
} }
// While we should be able to stop here, after getting the function pointers, some anti-virus actually causes // While we should be able to stop here, after getting the function pointers, some anti-virus actually causes
// failures in RIOCreateRequestQueue, so keep going to be certain this is supported. // failures in RIOCreateRequestQueue, so keep going to be certain this is supported.
cq, err = CreatePolledCompletionQueue(2) var iocp windows.Handle
iocp, err = windows.CreateIoCompletionPort(windows.InvalidHandle, 0, 0, 0)
if err != nil {
return
}
defer windows.CloseHandle(iocp)
var overlapped windows.Overlapped
cq, err = CreateIOCPCompletionQueue(2, iocp, 0, &overlapped)
if err != nil { if err != nil {
return return
} }
@@ -161,6 +169,7 @@ func CreateIOCPCompletionQueue(queueSize uint32, iocp windows.Handle, key uintpt
notificationCompletion := &iocpNotificationCompletion{ notificationCompletion := &iocpNotificationCompletion{
completionType: iocpCompletion, completionType: iocpCompletion,
iocp: iocp, iocp: iocp,
key: key,
overlapped: overlapped, overlapped: overlapped,
} }
ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0) ret, _, err := syscall.Syscall(extensionFunctionTable.rioCreateCompletionQueue, 2, uintptr(queueSize), uintptr(unsafe.Pointer(notificationCompletion)), 0)

View File

@@ -11,9 +11,6 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
"golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/conn"
"golang.zx2c4.com/wireguard/ratelimiter" "golang.zx2c4.com/wireguard/ratelimiter"
"golang.zx2c4.com/wireguard/rwcancel" "golang.zx2c4.com/wireguard/rwcancel"
@@ -400,7 +397,9 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() {
device.peers.RUnlock() device.peers.RUnlock()
} }
func unsafeCloseBind(device *Device) error { // closeBindLocked closes the device's net.bind.
// The caller must hold the net mutex.
func closeBindLocked(device *Device) error {
var err error var err error
netc := &device.net netc := &device.net
if netc.netlinkCancel != nil { if netc.netlinkCancel != nil {
@@ -455,7 +454,7 @@ func (device *Device) BindUpdate() error {
defer device.net.Unlock() defer device.net.Unlock()
// close existing sockets // close existing sockets
if err := unsafeCloseBind(device); err != nil { if err := closeBindLocked(device); err != nil {
return err return err
} }
@@ -466,8 +465,9 @@ func (device *Device) BindUpdate() error {
// bind to new port // bind to new port
var err error var err error
var recvFns []conn.ReceiveFunc
netc := &device.net netc := &device.net
netc.port, err = netc.bind.Open(netc.port) recvFns, netc.port, err = netc.bind.Open(netc.port)
if err != nil { if err != nil {
netc.port = 0 netc.port = 0
return err return err
@@ -499,11 +499,12 @@ func (device *Device) BindUpdate() error {
device.peers.RUnlock() device.peers.RUnlock()
// start receiving routines // start receiving routines
device.net.stopping.Add(2) device.net.stopping.Add(len(recvFns))
device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind) for _, fn := range recvFns {
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind) go device.RoutineReceiveIncoming(fn)
}
device.log.Verbosef("UDP bind has been updated") device.log.Verbosef("UDP bind has been updated")
return nil return nil
@@ -511,7 +512,7 @@ func (device *Device) BindUpdate() error {
func (device *Device) BindClose() error { func (device *Device) BindClose() error {
device.net.Lock() device.net.Lock()
err := unsafeCloseBind(device) err := closeBindLocked(device)
device.net.Unlock() device.net.Unlock()
return err return err
} }

View File

@@ -9,8 +9,8 @@ func (device *Device) DisableSomeRoamingForBrokenMobileSemantics() {
device.peers.RLock() device.peers.RLock()
for _, peer := range device.peers.keyMap { for _, peer := range device.peers.keyMap {
peer.Lock() peer.Lock()
defer peer.Unlock()
peer.disableRoaming = peer.endpoint != nil peer.disableRoaming = peer.endpoint != nil
peer.Unlock()
} }
device.peers.RUnlock() device.peers.RUnlock()
} }

View File

@@ -68,15 +68,16 @@ func (peer *Peer) keepKeyFreshReceiving() {
* Every time the bind is updated a new routine is started for * Every time the bind is updated a new routine is started for
* IPv4 and IPv6 (separately) * IPv4 and IPv6 (separately)
*/ */
func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) { func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
recvName := recv.PrettyName()
defer func() { defer func() {
device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP) device.log.Verbosef("Routine: receive incoming %s - stopped", recvName)
device.queue.decryption.wg.Done() device.queue.decryption.wg.Done()
device.queue.handshake.wg.Done() device.queue.handshake.wg.Done()
device.net.stopping.Done() device.net.stopping.Done()
}() }()
device.log.Verbosef("Routine: receive incoming IPv%d - started", IP) device.log.Verbosef("Routine: receive incoming %s - started", recvName)
// receive datagrams until conn is closed // receive datagrams until conn is closed
@@ -90,24 +91,21 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
) )
for { for {
switch IP { size, endpoint, err = recv(buffer[:])
case ipv4.Version:
size, endpoint, err = bind.ReceiveIPv4(buffer[:])
case ipv6.Version:
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
default:
panic("invalid IP version")
}
if err != nil { if err != nil {
device.PutMessageBuffer(buffer) device.PutMessageBuffer(buffer)
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
return return
} }
if neterr, ok := err.(net.Error); ok && !neterr.Temporary() {
return
}
device.log.Errorf("Failed to receive packet: %v", err) device.log.Errorf("Failed to receive packet: %v", err)
if deathSpiral < 10 { if deathSpiral < 10 {
deathSpiral++ deathSpiral++
time.Sleep(time.Second / 3) time.Sleep(time.Second / 3)
buffer = device.GetMessageBuffer()
continue continue
} }
return return

23
main.go
View File

@@ -33,25 +33,28 @@ const (
) )
func printUsage() { func printUsage() {
fmt.Printf("usage:\n") fmt.Printf("Usage: %s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
fmt.Printf("%s [-f/--foreground] INTERFACE-NAME\n", os.Args[0])
} }
func warning() { func warning() {
if runtime.GOOS != "linux" || os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" { switch runtime.GOOS {
case "linux", "freebsd", "openbsd":
if os.Getenv(ENV_WG_PROCESS_FOREGROUND) == "1" {
return
}
default:
return return
} }
fmt.Fprintln(os.Stderr, "┌───────────────────────────────────────────────────┐") fmt.Fprintln(os.Stderr, "┌──────────────────────────────────────────────────────┐")
fmt.Fprintln(os.Stderr, "│ │") fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "│ Running this software on Linux is unnecessary, │") fmt.Fprintln(os.Stderr, "│ Running wireguard-go is not required because this │")
fmt.Fprintln(os.Stderr, "│ because the Linux kernel has built-in first │") fmt.Fprintln(os.Stderr, "│ kernel has first class support for WireGuard. For │")
fmt.Fprintln(os.Stderr, "│ class support for WireGuard, which will be │")
fmt.Fprintln(os.Stderr, "│ faster, slicker, and better integrated. For │")
fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │") fmt.Fprintln(os.Stderr, "│ information on installing the kernel module, │")
fmt.Fprintln(os.Stderr, "│ please visit: <https://wireguard.com/install>. │") fmt.Fprintln(os.Stderr, "│ please visit: │")
fmt.Fprintln(os.Stderr, "│ https://www.wireguard.com/install/ │")
fmt.Fprintln(os.Stderr, "│ │") fmt.Fprintln(os.Stderr, "│ │")
fmt.Fprintln(os.Stderr, "└───────────────────────────────────────────────────┘") fmt.Fprintln(os.Stderr, "└──────────────────────────────────────────────────────┘")
} }
func main() { func main() {

View File

@@ -6,62 +6,52 @@
package tun package tun
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io"
"net" "net"
"os" "os"
"sync" "sync"
"syscall" "syscall"
"unsafe" "unsafe"
"golang.org/x/net/ipv6"
"golang.org/x/sys/unix" "golang.org/x/sys/unix"
) )
// _TUNSIFHEAD, value derived from sys/net/{if_tun,ioccom}.h
// const _TUNSIFHEAD = ((0x80000000) | (((4) & ((1 << 13) - 1) ) << 16) | (uint32(byte('t')) << 8) | (96))
const ( const (
_TUNSIFHEAD = 0x80047460 _TUNSIFHEAD = 0x80047460
_TUNSIFMODE = 0x8004745e _TUNSIFMODE = 0x8004745e
_TUNGIFNAME = 0x4020745d
_TUNSIFPID = 0x2000745f _TUNSIFPID = 0x2000745f
_SIOCGIFINFO_IN6 = 0xc048696c
_SIOCSIFINFO_IN6 = 0xc048696d
_ND6_IFF_AUTO_LINKLOCAL = 0x20
_ND6_IFF_NO_DAD = 0x100
) )
// TODO: move into x/sys/unix // Iface requests with just the name
const ( type ifreqName struct {
SIOCGIFINFO_IN6 = 0xc048696c Name [unix.IFNAMSIZ]byte
SIOCSIFINFO_IN6 = 0xc048696d _ [16]byte
ND6_IFF_AUTO_LINKLOCAL = 0x20 }
ND6_IFF_NO_DAD = 0x100
)
// Iface status string max len // Iface requests with a pointer
const _IFSTATMAX = 800 type ifreqPtr struct {
const SIZEOF_UINTPTR = 4 << (^uintptr(0) >> 32 & 1)
// structure for iface requests with a pointer
type ifreq_ptr struct {
Name [unix.IFNAMSIZ]byte Name [unix.IFNAMSIZ]byte
Data uintptr Data uintptr
Pad0 [16 - SIZEOF_UINTPTR]byte _ [16 - unsafe.Sizeof(uintptr(0))]byte
} }
// Structure for iface mtu get/set ioctls // Iface requests with MTU
type ifreq_mtu struct { type ifreqMtu struct {
Name [unix.IFNAMSIZ]byte Name [unix.IFNAMSIZ]byte
MTU uint32 MTU uint32
Pad0 [12]byte _ [12]byte
} }
// Structure for interface status request ioctl // ND6 flag manipulation
type ifstat struct { type nd6Req struct {
IfsName [unix.IFNAMSIZ]byte
Ascii [_IFSTATMAX]byte
}
// Structures for nd6 flag manipulation
type in6_ndireq struct {
Name [unix.IFNAMSIZ]byte Name [unix.IFNAMSIZ]byte
Linkmtu uint32 Linkmtu uint32
Maxmtu uint32 Maxmtu uint32
@@ -99,7 +89,7 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
retry: retry:
n, err := unix.Read(tun.routeSocket, data) n, err := unix.Read(tun.routeSocket, data)
if err != nil { if err != nil {
if errno, ok := err.(syscall.Errno); ok && errno == syscall.EINTR { if errors.Is(err, syscall.EINTR) {
goto retry goto retry
} }
tun.errors <- err tun.errors <- err
@@ -143,91 +133,17 @@ func (tun *NativeTun) routineRouteListener(tunIfindex int) {
} }
func tunName(fd uintptr) (string, error) { func tunName(fd uintptr) (string, error) {
//Terrible hack to make up for freebsd not having a TUNGIFNAME var ifreq ifreqName
_, _, err := unix.Syscall(unix.SYS_IOCTL, fd, _TUNGIFNAME, uintptr(unsafe.Pointer(&ifreq)))
//First, make sure the tun pid matches this proc's pid if err != 0 {
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(_TUNSIFPID),
uintptr(0),
)
if errno != 0 {
return "", fmt.Errorf("failed to set tun device PID: %s", errno.Error())
}
// Open iface control socket
confd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil {
return "", err return "", err
} }
return unix.ByteSliceToString(ifreq.Name[:]), nil
defer unix.Close(confd)
procPid := os.Getpid()
//Try to find interface with matching PID
for i := 1; ; i++ {
iface, _ := net.InterfaceByIndex(i)
if err != nil || iface == nil {
break
}
// Structs for getting data in and out of SIOCGIFSTATUS ioctl
var ifstatus ifstat
copy(ifstatus.IfsName[:], iface.Name)
// Make the syscall to get the status string
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(confd),
uintptr(unix.SIOCGIFSTATUS),
uintptr(unsafe.Pointer(&ifstatus)),
)
if errno != 0 {
continue
}
nullStr := ifstatus.Ascii[:]
i := bytes.IndexByte(nullStr, 0)
if i < 1 {
continue
}
statStr := string(nullStr[:i])
var pidNum int = 0
// Finally get the owning PID
// Format string taken from sys/net/if_tun.c
_, err := fmt.Sscanf(statStr, "\tOpened by PID %d\n", &pidNum)
if err != nil {
continue
}
if pidNum == procPid {
return iface.Name, nil
}
}
return "", nil
} }
// Destroy a named system interface // Destroy a named system interface
func tunDestroy(name string) error { func tunDestroy(name string) error {
// Open control socket. fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
var fd int
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil { if err != nil {
return err return err
} }
@@ -235,14 +151,9 @@ func tunDestroy(name string) error {
var ifr [32]byte var ifr [32]byte
copy(ifr[:], name) copy(ifr[:], name)
_, _, errno := unix.Syscall( _, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCIFDESTROY), uintptr(unsafe.Pointer(&ifr[0])))
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCIFDESTROY),
uintptr(unsafe.Pointer(&ifr[0])),
)
if errno != 0 { if errno != 0 {
return fmt.Errorf("failed to destroy interface %s: %s", name, errno.Error()) return fmt.Errorf("failed to destroy interface %s: %w", name, errno)
} }
return nil return nil
@@ -278,104 +189,68 @@ func CreateTUN(name string, mtu int) (Device, error) {
ifheadmode := 1 ifheadmode := 1
var errno syscall.Errno var errno syscall.Errno
tun.operateOnFd(func(fd uintptr) { tun.operateOnFd(func(fd uintptr) {
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFHEAD, uintptr(unsafe.Pointer(&ifheadmode)))
unix.SYS_IOCTL,
fd,
uintptr(_TUNSIFHEAD),
uintptr(unsafe.Pointer(&ifheadmode)),
)
}) })
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to put into IFHEAD mode: %w", errno) return nil, fmt.Errorf("unable to put into IFHEAD mode: %w", errno)
} }
// Get out of PPP mode. // Get out of PTP mode.
ifflags := syscall.IFF_BROADCAST ifflags := syscall.IFF_BROADCAST | syscall.IFF_MULTICAST
tun.operateOnFd(func(fd uintptr) { tun.operateOnFd(func(fd uintptr) {
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, uintptr(_TUNSIFMODE), uintptr(unsafe.Pointer(&ifflags)))
unix.SYS_IOCTL,
fd,
uintptr(_TUNSIFMODE),
uintptr(unsafe.Pointer(&ifflags)),
)
}) })
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to put into IFF_BROADCAST mode: %w", errno) return nil, fmt.Errorf("unable to put into IFF_BROADCAST mode: %w", errno)
} }
// Open control sockets // Disable link-local v6, not just because WireGuard doesn't do that anyway, but
confd, err := unix.Socket( // also because there are serious races with attaching and detaching LLv6 addresses
unix.AF_INET, // in relation to interface lifetime within the FreeBSD kernel.
unix.SOCK_DGRAM, confd6, err := unix.Socket(unix.AF_INET6, unix.SOCK_DGRAM, 0)
0,
)
if err != nil {
tunFile.Close()
tunDestroy(assignedName)
return nil, err
}
defer unix.Close(confd)
confd6, err := unix.Socket(
unix.AF_INET6,
unix.SOCK_DGRAM,
0,
)
if err != nil { if err != nil {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, err return nil, err
} }
defer unix.Close(confd6) defer unix.Close(confd6)
var ndireq nd6Req
// Disable link-local v6, not just because WireGuard doesn't do that anyway, but
// also because there are serious races with attaching and detaching LLv6 addresses
// in relation to interface lifetime within the FreeBSD kernel.
var ndireq in6_ndireq
copy(ndireq.Name[:], assignedName) copy(ndireq.Name[:], assignedName)
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCGIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq)))
unix.SYS_IOCTL,
uintptr(confd6),
uintptr(SIOCGIFINFO_IN6),
uintptr(unsafe.Pointer(&ndireq)),
)
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to get nd6 flags for %s: %w", assignedName, errno) return nil, fmt.Errorf("unable to get nd6 flags for %s: %w", assignedName, errno)
} }
ndireq.Flags = ndireq.Flags &^ ND6_IFF_AUTO_LINKLOCAL ndireq.Flags = ndireq.Flags &^ _ND6_IFF_AUTO_LINKLOCAL
ndireq.Flags = ndireq.Flags | ND6_IFF_NO_DAD ndireq.Flags = ndireq.Flags | _ND6_IFF_NO_DAD
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd6), uintptr(_SIOCSIFINFO_IN6), uintptr(unsafe.Pointer(&ndireq)))
unix.SYS_IOCTL,
uintptr(confd6),
uintptr(SIOCSIFINFO_IN6),
uintptr(unsafe.Pointer(&ndireq)),
)
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
return nil, fmt.Errorf("Unable to set nd6 flags for %s: %w", assignedName, errno) return nil, fmt.Errorf("unable to set nd6 flags for %s: %w", assignedName, errno)
} }
if name != "" { if name != "" {
// Rename the interface confd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
if err != nil {
tunFile.Close()
tunDestroy(assignedName)
return nil, err
}
defer unix.Close(confd)
var newnp [unix.IFNAMSIZ]byte var newnp [unix.IFNAMSIZ]byte
copy(newnp[:], name) copy(newnp[:], name)
var ifr ifreq_ptr var ifr ifreqPtr
copy(ifr.Name[:], assignedName) copy(ifr.Name[:], assignedName)
ifr.Data = uintptr(unsafe.Pointer(&newnp[0])) ifr.Data = uintptr(unsafe.Pointer(&newnp[0]))
_, _, errno = unix.Syscall( _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd), uintptr(unix.SIOCSIFNAME), uintptr(unsafe.Pointer(&ifr)))
unix.SYS_IOCTL,
uintptr(confd),
uintptr(unix.SIOCSIFNAME),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 { if errno != 0 {
tunFile.Close() tunFile.Close()
tunDestroy(assignedName) tunDestroy(assignedName)
@@ -387,13 +262,21 @@ func CreateTUN(name string, mtu int) (Device, error) {
} }
func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { func CreateTUNFromFile(file *os.File, mtu int) (Device, error) {
tun := &NativeTun{ tun := &NativeTun{
tunFile: file, tunFile: file,
events: make(chan Event, 10), events: make(chan Event, 10),
errors: make(chan error, 1), errors: make(chan error, 1),
} }
var errno syscall.Errno
tun.operateOnFd(func(fd uintptr) {
_, _, errno = unix.Syscall(unix.SYS_IOCTL, fd, _TUNSIFPID, uintptr(0))
})
if errno != 0 {
tun.tunFile.Close()
return nil, fmt.Errorf("unable to become controlling TUN process: %w", errno)
}
name, err := tun.Name() name, err := tun.Name()
if err != nil { if err != nil {
tun.tunFile.Close() tun.tunFile.Close()
@@ -464,27 +347,26 @@ func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
} }
} }
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) { func (tun *NativeTun) Write(buf []byte, offset int) (int, error) {
if offset < 4 {
// reserve space for header return 0, io.ErrShortBuffer
buff = buff[offset-4:]
// add packet information header
buff[0] = 0x00
buff[1] = 0x00
buff[2] = 0x00
if buff[4]>>4 == ipv6.Version {
buff[3] = unix.AF_INET6
} else {
buff[3] = unix.AF_INET
} }
buf = buf[offset-4:]
// write if len(buf) < 5 {
return 0, io.ErrShortBuffer
return tun.tunFile.Write(buff) }
buf[0] = 0x00
buf[1] = 0x00
buf[2] = 0x00
switch buf[4] >> 4 {
case 4:
buf[3] = unix.AF_INET
case 6:
buf[3] = unix.AF_INET6
default:
return 0, unix.EAFNOSUPPORT
}
return tun.tunFile.Write(buf)
} }
func (tun *NativeTun) Flush() error { func (tun *NativeTun) Flush() error {
@@ -515,70 +397,34 @@ func (tun *NativeTun) Close() error {
} }
func (tun *NativeTun) setMTU(n int) error { func (tun *NativeTun) setMTU(n int) error {
// open datagram socket fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
var fd int
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil { if err != nil {
return err return err
} }
defer unix.Close(fd) defer unix.Close(fd)
// do ioctl call var ifr ifreqMtu
var ifr ifreq_mtu
copy(ifr.Name[:], tun.name) copy(ifr.Name[:], tun.name)
ifr.MTU = uint32(n) ifr.MTU = uint32(n)
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCSIFMTU), uintptr(unsafe.Pointer(&ifr)))
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCSIFMTU),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 { if errno != 0 {
return fmt.Errorf("failed to set MTU on %s", tun.name) return fmt.Errorf("failed to set MTU on %s: %w", tun.name, errno)
} }
return nil return nil
} }
func (tun *NativeTun) MTU() (int, error) { func (tun *NativeTun) MTU() (int, error) {
// open datagram socket fd, err := unix.Socket(unix.AF_INET, unix.SOCK_DGRAM, 0)
fd, err := unix.Socket(
unix.AF_INET,
unix.SOCK_DGRAM,
0,
)
if err != nil { if err != nil {
return 0, err return 0, err
} }
defer unix.Close(fd) defer unix.Close(fd)
// do ioctl call var ifr ifreqMtu
var ifr ifreq_mtu
copy(ifr.Name[:], tun.name) copy(ifr.Name[:], tun.name)
_, _, errno := unix.Syscall(unix.SYS_IOCTL, uintptr(fd), uintptr(unix.SIOCGIFMTU), uintptr(unsafe.Pointer(&ifr)))
_, _, errno := unix.Syscall(
unix.SYS_IOCTL,
uintptr(fd),
uintptr(unix.SIOCGIFMTU),
uintptr(unsafe.Pointer(&ifr)),
)
if errno != 0 { if errno != 0 {
return 0, fmt.Errorf("failed to get MTU on %s", tun.name) return 0, fmt.Errorf("failed to get MTU on %s: %w", tun.name, errno)
} }
return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil return int(*(*int32)(unsafe.Pointer(&ifr.MTU))), nil
} }

View File

@@ -1,3 +1,3 @@
package main package main
const Version = "0.0.20210323" const Version = "0.0.20210424"