Rewrite timers and related state machines

This commit is contained in:
Jason A. Donenfeld
2018-05-07 22:27:03 +02:00
parent 375dcbd4ae
commit 233f079a94
14 changed files with 453 additions and 602 deletions

View File

@@ -31,7 +31,7 @@ type QueueInboundElement struct {
buffer *[MaxMessageSize]byte
packet []byte
counter uint64
keyPair *KeyPair
keyPair *Keypair
endpoint Endpoint
}
@@ -99,6 +99,21 @@ func (device *Device) addToHandshakeQueue(
}
}
/* Called when a new authenticated message has been received
*
* NOTE: Not thread safe, but called by sequential receiver!
*/
func (peer *Peer) keepKeyFreshReceiving() {
if peer.timers.sentLastMinuteHandshake {
return
}
kp := peer.keyPairs.Current()
if kp != nil && kp.isInitiator && time.Now().Sub(kp.created) > (RejectAfterTime-KeepaliveTimeout-RekeyTimeout) {
peer.timers.sentLastMinuteHandshake = true
peer.SendHandshakeInitiation(false)
}
}
/* Receives incoming datagrams for the device
*
* Every time the bind is updated a new routine is started for
@@ -245,7 +260,7 @@ func (device *Device) RoutineDecryption() {
for {
select {
case <-device.signal.stop.Wait():
case <-device.signals.stop:
return
case elem, ok := <-device.queue.decryption:
@@ -317,7 +332,7 @@ func (device *Device) RoutineHandshake() {
for {
select {
case elem, ok = <-device.queue.handshake:
case <-device.signal.stop.Wait():
case <-device.signals.stop:
return
}
@@ -441,8 +456,8 @@ func (device *Device) RoutineHandshake() {
// update timers
peer.event.anyAuthenticatedPacketTraversal.Fire()
peer.event.anyAuthenticatedPacketReceived.Fire()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
// update endpoint
@@ -460,10 +475,11 @@ func (device *Device) RoutineHandshake() {
continue
}
peer.TimerEphemeralKeyCreated()
peer.NewKeyPair()
if peer.NewKeypair() == nil {
continue
}
logDebug.Println(peer, ": Creating handshake response")
logDebug.Println(peer, ": Sending handshake response")
writer := bytes.NewBuffer(temp[:0])
binary.Write(writer, binary.LittleEndian, response)
@@ -472,9 +488,10 @@ func (device *Device) RoutineHandshake() {
// send response
peer.timers.lastSentHandshake = time.Now()
err = peer.SendBuffer(packet)
if err == nil {
peer.event.anyAuthenticatedPacketTraversal.Fire()
peer.timersAnyAuthenticatedPacketTraversal()
} else {
logError.Println(peer, ": Failed to send handshake response", err)
}
@@ -510,18 +527,23 @@ func (device *Device) RoutineHandshake() {
logDebug.Println(peer, ": Received handshake response")
peer.TimerEphemeralKeyCreated()
// update timers
peer.event.anyAuthenticatedPacketTraversal.Fire()
peer.event.anyAuthenticatedPacketReceived.Fire()
peer.event.handshakeCompleted.Fire()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
// derive key-pair
peer.NewKeyPair()
peer.SendKeepAlive()
if peer.NewKeypair() == nil {
continue
}
peer.timersHandshakeComplete()
peer.SendKeepalive()
select {
case peer.signals.newKeypairArrived <- struct{}{}:
default:
}
}
}
}
@@ -569,38 +591,41 @@ func (peer *Peer) RoutineSequentialReceiver() {
continue
}
peer.event.anyAuthenticatedPacketTraversal.Fire()
peer.event.anyAuthenticatedPacketReceived.Fire()
peer.KeepKeyFreshReceiving()
// check if using new key-pair
kp := &peer.keyPairs
kp.mutex.Lock()
if kp.next == elem.keyPair {
peer.event.handshakeCompleted.Fire()
if kp.previous != nil {
device.DeleteKeyPair(kp.previous)
}
kp.previous = kp.current
kp.current = kp.next
kp.next = nil
}
kp.mutex.Unlock()
// update endpoint
peer.mutex.Lock()
peer.endpoint = elem.endpoint
peer.mutex.Unlock()
// check for keep-alive
// check if using new key-pair
kp := &peer.keyPairs
kp.mutex.Lock() //TODO: make this into an RW lock to reduce contention here for the equality check which is rarely true
if kp.next == elem.keyPair {
old := kp.previous
kp.previous = kp.current
device.DeleteKeypair(old)
kp.current = kp.next
kp.next = nil
peer.timersHandshakeComplete()
select {
case peer.signals.newKeypairArrived <- struct{}{}:
default:
}
}
kp.mutex.Unlock()
peer.keepKeyFreshReceiving()
peer.timersAnyAuthenticatedPacketTraversal()
peer.timersAnyAuthenticatedPacketReceived()
// check for keepalive
if len(elem.packet) == 0 {
logDebug.Println(peer, ": Received keep-alive")
logDebug.Println(peer, ": Receiving keepalive packet")
continue
}
peer.event.dataReceived.Fire()
peer.timersDataReceived()
// verify source and strip padding