conn: make binds replacable
Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
@@ -279,11 +279,12 @@ func (device *Device) SetPrivateKey(sk NoisePrivateKey) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
||||
func NewDevice(tunDevice tun.Device, bind conn.Bind, logger *Logger) *Device {
|
||||
device := new(Device)
|
||||
device.state.state = uint32(deviceStateDown)
|
||||
device.closed = make(chan struct{})
|
||||
device.log = logger
|
||||
device.net.bind = bind
|
||||
device.tun.device = tunDevice
|
||||
mtu, err := device.tun.device.MTU()
|
||||
if err != nil {
|
||||
@@ -302,11 +303,6 @@ func NewDevice(tunDevice tun.Device, logger *Logger) *Device {
|
||||
device.queue.encryption = newOutboundQueue()
|
||||
device.queue.decryption = newInboundQueue()
|
||||
|
||||
// prepare net
|
||||
|
||||
device.net.port = 0
|
||||
device.net.bind = nil
|
||||
|
||||
// start workers
|
||||
|
||||
cpus := runtime.NumCPU()
|
||||
@@ -414,7 +410,6 @@ func unsafeCloseBind(device *Device) error {
|
||||
}
|
||||
if netc.bind != nil {
|
||||
err = netc.bind.Close()
|
||||
netc.bind = nil
|
||||
}
|
||||
netc.stopping.Wait()
|
||||
return err
|
||||
@@ -474,16 +469,14 @@ func (device *Device) BindUpdate() error {
|
||||
// bind to new port
|
||||
var err error
|
||||
netc := &device.net
|
||||
netc.bind, netc.port, err = conn.CreateBind(netc.port)
|
||||
netc.port, err = netc.bind.Open(netc.port)
|
||||
if err != nil {
|
||||
netc.bind = nil
|
||||
netc.port = 0
|
||||
return err
|
||||
}
|
||||
netc.netlinkCancel, err = device.startRouteListener(netc.bind)
|
||||
if err != nil {
|
||||
netc.bind.Close()
|
||||
netc.bind = nil
|
||||
netc.port = 0
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/tun/tuntest"
|
||||
)
|
||||
|
||||
@@ -158,7 +159,7 @@ func genTestPair(tb testing.TB) (pair testPair) {
|
||||
if _, ok := tb.(*testing.B); ok && !testing.Verbose() {
|
||||
level = LogLevelError
|
||||
}
|
||||
p.dev = NewDevice(p.tun.TUN(), NewLogger(level, fmt.Sprintf("dev%d: ", i)))
|
||||
p.dev = NewDevice(p.tun.TUN(), conn.NewDefaultBind(), NewLogger(level, fmt.Sprintf("dev%d: ", i)))
|
||||
if err := p.dev.IpcSet(cfg[i]); err != nil {
|
||||
tb.Errorf("failed to configure device %d: %v", i, err)
|
||||
p.dev.Close()
|
||||
@@ -332,7 +333,7 @@ func randDevice(t *testing.T) *Device {
|
||||
}
|
||||
tun := newDummyTUN("dummy")
|
||||
logger := NewLogger(LogLevelError, "")
|
||||
device := NewDevice(tun, logger)
|
||||
device := NewDevice(tun, conn.NewDefaultBind(), logger)
|
||||
device.SetPrivateKey(sk)
|
||||
return device
|
||||
}
|
||||
|
||||
@@ -126,13 +126,8 @@ func (peer *Peer) SendBuffer(buffer []byte) error {
|
||||
peer.device.net.RLock()
|
||||
defer peer.device.net.RUnlock()
|
||||
|
||||
if peer.device.net.bind == nil {
|
||||
// Packets can leak through to SendBuffer while the device is closing.
|
||||
// When that happens, drop them silently to avoid spurious errors.
|
||||
if peer.device.isClosed() {
|
||||
return nil
|
||||
}
|
||||
return errors.New("no bind")
|
||||
if peer.device.isClosed() {
|
||||
return nil
|
||||
}
|
||||
|
||||
peer.RLock()
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
// +build !linux android
|
||||
// +build !linux
|
||||
|
||||
package device
|
||||
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
// +build !android
|
||||
|
||||
/* SPDX-License-Identifier: MIT
|
||||
*
|
||||
* Copyright (C) 2017-2021 WireGuard LLC. All Rights Reserved.
|
||||
@@ -21,11 +19,16 @@ import (
|
||||
"unsafe"
|
||||
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
)
|
||||
|
||||
func (device *Device) startRouteListener(bind conn.Bind) (*rwcancel.RWCancel, error) {
|
||||
if _, ok := bind.(*conn.LinuxSocketBind); !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
netlinkSock, err := createNetlinkRouteSocket()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -109,11 +112,11 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||
pePtr.peer.Unlock()
|
||||
break
|
||||
}
|
||||
if uint32(pePtr.peer.endpoint.(*conn.NativeEndpoint).Src4().Ifindex) == ifidx {
|
||||
if uint32(pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).Src4().Ifindex) == ifidx {
|
||||
pePtr.peer.Unlock()
|
||||
break
|
||||
}
|
||||
pePtr.peer.endpoint.(*conn.NativeEndpoint).ClearSrc()
|
||||
pePtr.peer.endpoint.(*conn.LinuxSocketEndpoint).ClearSrc()
|
||||
pePtr.peer.Unlock()
|
||||
}
|
||||
attr = attr[attrhdr.Len:]
|
||||
@@ -133,7 +136,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||
peer.RUnlock()
|
||||
continue
|
||||
}
|
||||
nativeEP, _ := peer.endpoint.(*conn.NativeEndpoint)
|
||||
nativeEP, _ := peer.endpoint.(*conn.LinuxSocketEndpoint)
|
||||
if nativeEP == nil {
|
||||
peer.RUnlock()
|
||||
continue
|
||||
@@ -176,7 +179,7 @@ func (device *Device) routineRouteListener(bind conn.Bind, netlinkSock int, netl
|
||||
Len: 8,
|
||||
Type: unix.RTA_MARK,
|
||||
},
|
||||
uint32(bind.LastMark()),
|
||||
device.net.fwmark,
|
||||
}
|
||||
nlmsg.hdr.Len = uint32(unsafe.Sizeof(nlmsg))
|
||||
reqPeerLock.Lock()
|
||||
|
||||
@@ -18,7 +18,6 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
)
|
||||
|
||||
@@ -331,7 +330,7 @@ func (device *Device) handlePeerLine(peer *ipcSetPeer, key, value string) error
|
||||
|
||||
case "endpoint":
|
||||
device.log.Verbosef("%v - UAPI: Updating endpoint", peer.Peer)
|
||||
endpoint, err := conn.CreateEndpoint(value)
|
||||
endpoint, err := device.net.bind.ParseEndpoint(value)
|
||||
if err != nil {
|
||||
return ipcErrorf(ipc.IpcErrorInvalid, "failed to set endpoint %v: %w", value, err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user