all: make conn.Bind.Open return a slice of receive functions
Instead of hard-coding exactly two sources from which to receive packets (an IPv4 source and an IPv6 source), allow the conn.Bind to specify a set of sources. Beneficial consequences: * If there's no IPv6 support on a system, conn.Bind.Open can choose not to return a receive function for it, which is simpler than tracking that state in the bind. This simplification removes existing data races from both conn.StdNetBind and bindtest.ChannelBind. * If there are more than two sources on a system, the conn.Bind no longer needs to add a separate muxing layer. Signed-off-by: Josh Bleecher Snyder <josharian@gmail.com>
This commit is contained in:
committed by
Jason A. Donenfeld
parent
8ed83e0427
commit
10533c3e73
@@ -11,9 +11,6 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"golang.org/x/net/ipv4"
|
||||
"golang.org/x/net/ipv6"
|
||||
|
||||
"golang.zx2c4.com/wireguard/conn"
|
||||
"golang.zx2c4.com/wireguard/ratelimiter"
|
||||
"golang.zx2c4.com/wireguard/rwcancel"
|
||||
@@ -468,8 +465,9 @@ func (device *Device) BindUpdate() error {
|
||||
|
||||
// bind to new port
|
||||
var err error
|
||||
var recvFns []conn.ReceiveFunc
|
||||
netc := &device.net
|
||||
netc.port, err = netc.bind.Open(netc.port)
|
||||
recvFns, netc.port, err = netc.bind.Open(netc.port)
|
||||
if err != nil {
|
||||
netc.port = 0
|
||||
return err
|
||||
@@ -501,11 +499,12 @@ func (device *Device) BindUpdate() error {
|
||||
device.peers.RUnlock()
|
||||
|
||||
// start receiving routines
|
||||
device.net.stopping.Add(2)
|
||||
device.queue.decryption.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
||||
device.queue.handshake.wg.Add(2) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
||||
go device.RoutineReceiveIncoming(ipv4.Version, netc.bind)
|
||||
go device.RoutineReceiveIncoming(ipv6.Version, netc.bind)
|
||||
device.net.stopping.Add(len(recvFns))
|
||||
device.queue.decryption.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.decryption
|
||||
device.queue.handshake.wg.Add(len(recvFns)) // each RoutineReceiveIncoming goroutine writes to device.queue.handshake
|
||||
for _, fn := range recvFns {
|
||||
go device.RoutineReceiveIncoming(fn)
|
||||
}
|
||||
|
||||
device.log.Verbosef("UDP bind has been updated")
|
||||
return nil
|
||||
|
||||
@@ -68,15 +68,15 @@ func (peer *Peer) keepKeyFreshReceiving() {
|
||||
* Every time the bind is updated a new routine is started for
|
||||
* IPv4 and IPv6 (separately)
|
||||
*/
|
||||
func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
|
||||
func (device *Device) RoutineReceiveIncoming(recv conn.ReceiveFunc) {
|
||||
defer func() {
|
||||
device.log.Verbosef("Routine: receive incoming IPv%d - stopped", IP)
|
||||
device.log.Verbosef("Routine: receive incoming %p - stopped", recv)
|
||||
device.queue.decryption.wg.Done()
|
||||
device.queue.handshake.wg.Done()
|
||||
device.net.stopping.Done()
|
||||
}()
|
||||
|
||||
device.log.Verbosef("Routine: receive incoming IPv%d - started", IP)
|
||||
device.log.Verbosef("Routine: receive incoming %p - started", recv)
|
||||
|
||||
// receive datagrams until conn is closed
|
||||
|
||||
@@ -90,14 +90,7 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
|
||||
)
|
||||
|
||||
for {
|
||||
switch IP {
|
||||
case ipv4.Version:
|
||||
size, endpoint, err = bind.ReceiveIPv4(buffer[:])
|
||||
case ipv6.Version:
|
||||
size, endpoint, err = bind.ReceiveIPv6(buffer[:])
|
||||
default:
|
||||
panic("invalid IP version")
|
||||
}
|
||||
size, endpoint, err = recv(buffer[:])
|
||||
|
||||
if err != nil {
|
||||
device.PutMessageBuffer(buffer)
|
||||
|
||||
Reference in New Issue
Block a user