device: simplify peer queue locking

Signed-off-by: Jason A. Donenfeld <Jason@zx2c4.com>
This commit is contained in:
Jason A. Donenfeld
2021-01-29 14:54:11 +01:00
parent f0f27d7fd2
commit 9263014ed3
4 changed files with 70 additions and 147 deletions

View File

@@ -174,7 +174,6 @@ func (device *Device) RoutineReceiveIncoming(IP int, bind conn.Bind) {
elem.Lock()
// add to decryption queues
peer.queue.RLock()
if peer.isRunning.Get() {
peer.queue.inbound <- elem
@@ -433,52 +432,25 @@ func (device *Device) RoutineHandshake() {
func (peer *Peer) RoutineSequentialReceiver() {
device := peer.device
var elem *QueueInboundElement
defer func() {
device.log.Verbosef("%v - Routine: sequential receiver - stopped", peer)
peer.routines.stopping.Done()
if elem != nil {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
peer.stopping.Done()
}()
device.log.Verbosef("%v - Routine: sequential receiver - started", peer)
for {
if elem != nil {
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
elem = nil
}
var elemOk bool
select {
case <-peer.routines.stop:
return
case elem, elemOk = <-peer.queue.inbound:
if !elemOk {
return
}
}
// wait for decryption
for elem := range peer.queue.inbound {
var err error
elem.Lock()
if elem.packet == nil {
// decryption failed
continue
goto skip
}
// check for replay
if !elem.keypair.replayFilter.ValidateCounter(elem.counter, RejectAfterMessages) {
continue
goto skip
}
// update endpoint
peer.SetEndpointFromPacket(elem.endpoint)
// check if using new keypair
if peer.ReceivedWithKeypair(elem.keypair) {
peer.timersHandshakeComplete()
peer.SendStagedPackets()
@@ -489,83 +461,63 @@ func (peer *Peer) RoutineSequentialReceiver() {
peer.timersAnyAuthenticatedPacketReceived()
atomic.AddUint64(&peer.stats.rxBytes, uint64(len(elem.packet)+MinMessageSize))
// check for keepalive
if len(elem.packet) == 0 {
device.log.Verbosef("%v - Receiving keepalive packet", peer)
continue
goto skip
}
peer.timersDataReceived()
// verify source and strip padding
switch elem.packet[0] >> 4 {
case ipv4.Version:
// strip padding
if len(elem.packet) < ipv4.HeaderLen {
continue
goto skip
}
field := elem.packet[IPv4offsetTotalLength : IPv4offsetTotalLength+2]
length := binary.BigEndian.Uint16(field)
if int(length) > len(elem.packet) || int(length) < ipv4.HeaderLen {
continue
goto skip
}
elem.packet = elem.packet[:length]
// verify IPv4 source
src := elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len]
if device.allowedips.LookupIPv4(src) != peer {
device.log.Verbosef("IPv4 packet with disallowed source address from %v", peer)
continue
goto skip
}
case ipv6.Version:
// strip padding
if len(elem.packet) < ipv6.HeaderLen {
continue
goto skip
}
field := elem.packet[IPv6offsetPayloadLength : IPv6offsetPayloadLength+2]
length := binary.BigEndian.Uint16(field)
length += ipv6.HeaderLen
if int(length) > len(elem.packet) {
continue
goto skip
}
elem.packet = elem.packet[:length]
// verify IPv6 source
src := elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len]
if device.allowedips.LookupIPv6(src) != peer {
device.log.Verbosef("IPv6 packet with disallowed source address from %v", peer)
continue
goto skip
}
default:
device.log.Verbosef("Packet with invalid IP version from %v", peer)
continue
goto skip
}
// write to tun device
offset := MessageTransportOffsetContent
_, err := device.tun.device.Write(elem.buffer[:offset+len(elem.packet)], offset)
_, err = device.tun.device.Write(elem.buffer[:MessageTransportOffsetContent+len(elem.packet)], MessageTransportOffsetContent)
if err != nil && !device.isClosed.Get() {
device.log.Errorf("Failed to write packet to TUN device: %v", err)
}
if len(peer.queue.inbound) == 0 {
err := device.tun.device.Flush()
err = device.tun.device.Flush()
if err != nil {
peer.device.log.Errorf("Unable to flush packets: %v", err)
}
}
skip:
device.PutMessageBuffer(elem.buffer)
device.PutInboundElement(elem)
}
}