From c2be9921fe46b336b1991c4e5984666bd75b941f Mon Sep 17 00:00:00 2001 From: Damien Neil Date: Thu, 12 Dec 2024 13:49:57 -0800 Subject: [PATCH] quic: remember which remote connection IDs have been retired Keep track of which peer-provided connection ID sequence numbers we have sent a RETIRE_CONNECTION_ID for, received an ack, and removed from our set of tracked IDs. Rework how we track retired connection IDs in general: Rather than keeping all state for retired IDs, just keep a set of the sequence numbers of IDs which we're in the process of retiring. Change-Id: I14da8b5295d5fbe8318c8afe556cbd2c8a56d856 Reviewed-on: https://go-review.googlesource.com/c/net/+/635717 LUCI-TryBot-Result: Go LUCI Reviewed-by: Jonathan Amsterdam --- quic/conn_id.go | 129 +++++++++++++++++++++---------------------- quic/conn_id_test.go | 49 ++++++++++++++++ quic/rangeset.go | 8 +++ 3 files changed, 119 insertions(+), 67 deletions(-) diff --git a/quic/conn_id.go b/quic/conn_id.go index 2efe8d6b5..2d50f14fa 100644 --- a/quic/conn_id.go +++ b/quic/conn_id.go @@ -9,6 +9,7 @@ package quic import ( "bytes" "crypto/rand" + "slices" ) // connIDState is a conn's connection IDs. @@ -25,8 +26,16 @@ type connIDState struct { remote []remoteConnID nextLocalSeq int64 - retireRemotePriorTo int64 // largest Retire Prior To value sent by the peer - peerActiveConnIDLimit int64 // peer's active_connection_id_limit transport parameter + peerActiveConnIDLimit int64 // peer's active_connection_id_limit + + // Handling of retirement of remote connection IDs. + // The rangesets track ID sequence numbers. + // IDs in need of retirement are added to remoteRetiring, + // moved to remoteRetiringSent once we send a RETIRE_CONECTION_ID frame, + // and removed from the set once retirement completes. + retireRemotePriorTo int64 // largest Retire Prior To value sent by the peer + remoteRetiring rangeset[int64] // remote IDs in need of retirement + remoteRetiringSent rangeset[int64] // remote IDs waiting for ack of retirement originalDstConnID []byte // expected original_destination_connection_id param retrySrcConnID []byte // expected retry_source_connection_id param @@ -45,9 +54,6 @@ type connID struct { // For the transient destination ID in a client's Initial packet, this is -1. seq int64 - // retired is set when the connection ID is retired. - retired bool - // send is set when the connection ID's state needs to be sent to the peer. // // For local IDs, this indicates a new ID that should be sent @@ -144,9 +150,7 @@ func (s *connIDState) srcConnID() []byte { // dstConnID is the Destination Connection ID to use in a sent packet. func (s *connIDState) dstConnID() (cid []byte, ok bool) { for i := range s.remote { - if !s.remote[i].retired { - return s.remote[i].cid, true - } + return s.remote[i].cid, true } return nil, false } @@ -154,14 +158,12 @@ func (s *connIDState) dstConnID() (cid []byte, ok bool) { // isValidStatelessResetToken reports whether the given reset token is // associated with a non-retired connection ID which we have used. func (s *connIDState) isValidStatelessResetToken(resetToken statelessResetToken) bool { - for i := range s.remote { - // We currently only use the first available remote connection ID, - // so any other reset token is not valid. - if !s.remote[i].retired { - return s.remote[i].resetToken == resetToken - } + if len(s.remote) == 0 { + return false } - return false + // We currently only use the first available remote connection ID, + // so any other reset token is not valid. + return s.remote[0].resetToken == resetToken } // setPeerActiveConnIDLimit sets the active_connection_id_limit @@ -174,7 +176,7 @@ func (s *connIDState) setPeerActiveConnIDLimit(c *Conn, lim int64) error { func (s *connIDState) issueLocalIDs(c *Conn) error { toIssue := min(int(s.peerActiveConnIDLimit), maxPeerActiveConnIDLimit) for i := range s.local { - if s.local[i].seq != -1 && !s.local[i].retired { + if s.local[i].seq != -1 { toIssue-- } } @@ -271,7 +273,7 @@ func (s *connIDState) handlePacket(c *Conn, ptype packetType, srcConnID []byte) } } case ptype == packetTypeHandshake && c.side == serverSide: - if len(s.local) > 0 && s.local[0].seq == -1 && !s.local[0].retired { + if len(s.local) > 0 && s.local[0].seq == -1 { // We're a server connection processing the first Handshake packet from // the client. Discard the transient, client-chosen connection ID used // for Initial packets; the client will never send it again. @@ -304,23 +306,29 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re } } + if seq < s.retireRemotePriorTo { + // This ID was already retired by a previous NEW_CONNECTION_ID frame. + // Nothing to do. + return nil + } + if retire > s.retireRemotePriorTo { + // Add newly-retired connection IDs to the set we need to send + // RETIRE_CONNECTION_ID frames for, and remove them from s.remote. + // + // (This might cause us to send a RETIRE_CONNECTION_ID for an ID we've + // never seen. That's fine.) + s.remoteRetiring.add(s.retireRemotePriorTo, retire) s.retireRemotePriorTo = retire + s.needSend = true + s.remote = slices.DeleteFunc(s.remote, func(rcid remoteConnID) bool { + return rcid.seq < s.retireRemotePriorTo + }) } have := false // do we already have this connection ID? - active := 0 for i := range s.remote { rcid := &s.remote[i] - if !rcid.retired && rcid.seq >= 0 && rcid.seq < s.retireRemotePriorTo { - s.retireRemote(rcid) - c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { - conns.retireResetToken(c, rcid.resetToken) - }) - } - if !rcid.retired { - active++ - } if rcid.seq == seq { if !bytes.Equal(rcid.cid, cid) { return localTransportError{ @@ -329,6 +337,7 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re } } have = true // yes, we've seen this sequence number + break } } @@ -345,18 +354,12 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re }, resetToken: resetToken, }) - if seq < s.retireRemotePriorTo { - // This ID was already retired by a previous NEW_CONNECTION_ID frame. - s.retireRemote(&s.remote[len(s.remote)-1]) - } else { - active++ - c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { - conns.addResetToken(c, resetToken) - }) - } + c.endpoint.connsMap.updateConnIDs(func(conns *connsMap) { + conns.addResetToken(c, resetToken) + }) } - if active > activeConnIDLimit { + if len(s.remote) > activeConnIDLimit { // Retired connection IDs (including newly-retired ones) do not count // against the limit. // https://www.rfc-editor.org/rfc/rfc9000.html#section-5.1.1-5 @@ -370,25 +373,18 @@ func (s *connIDState) handleNewConnID(c *Conn, seq, retire int64, cid []byte, re // for which RETIRE_CONNECTION_ID frames have not yet been acknowledged." // https://www.rfc-editor.org/rfc/rfc9000#section-5.1.2-6 // - // Set a limit of four times the active_connection_id_limit for - // the total number of remote connection IDs we keep state for locally. - if len(s.remote) > 4*activeConnIDLimit { + // Set a limit of three times the active_connection_id_limit for + // the total number of remote connection IDs we keep retirement state for. + if s.remoteRetiring.size()+s.remoteRetiringSent.size() > 3*activeConnIDLimit { return localTransportError{ code: errConnectionIDLimit, - reason: "too many unacknowledged RETIRE_CONNECTION_ID frames", + reason: "too many unacknowledged retired connection ids", } } return nil } -// retireRemote marks a remote connection ID as retired. -func (s *connIDState) retireRemote(rcid *remoteConnID) { - rcid.retired = true - rcid.send.setUnsent() - s.needSend = true -} - func (s *connIDState) handleRetireConnID(c *Conn, seq int64) error { if seq >= s.nextLocalSeq { return localTransportError{ @@ -424,20 +420,11 @@ func (s *connIDState) ackOrLossNewConnectionID(pnum packetNumber, seq int64, fat } func (s *connIDState) ackOrLossRetireConnectionID(pnum packetNumber, seq int64, fate packetFate) { - for i := 0; i < len(s.remote); i++ { - if s.remote[i].seq != seq { - continue - } - if fate == packetAcked { - // We have retired this connection ID, and the peer has acked. - // Discard its state completely. - s.remote = append(s.remote[:i], s.remote[i+1:]...) - } else { - // RETIRE_CONNECTION_ID frame was lost, mark for retransmission. - s.needSend = true - s.remote[i].send.ackOrLoss(pnum, fate) - } - return + s.remoteRetiringSent.sub(seq, seq+1) + if fate == packetLost { + // RETIRE_CONNECTION_ID frame was lost, mark for retransmission. + s.remoteRetiring.add(seq, seq+1) + s.needSend = true } } @@ -469,14 +456,22 @@ func (s *connIDState) appendFrames(c *Conn, pnum packetNumber, pto bool) bool { } s.local[i].send.setSent(pnum) } - for i := range s.remote { - if !s.remote[i].send.shouldSendPTO(pto) { - continue + if pto { + for _, r := range s.remoteRetiringSent { + for cid := r.start; cid < r.end; cid++ { + if !c.w.appendRetireConnectionIDFrame(cid) { + return false + } + } } - if !c.w.appendRetireConnectionIDFrame(s.remote[i].seq) { + } + for s.remoteRetiring.numRanges() > 0 { + cid := s.remoteRetiring.min() + if !c.w.appendRetireConnectionIDFrame(cid) { return false } - s.remote[i].send.setSent(pnum) + s.remoteRetiring.sub(cid, cid+1) + s.remoteRetiringSent.add(cid, cid+1) } s.needSend = false return true diff --git a/quic/conn_id_test.go b/quic/conn_id_test.go index d44472e81..2c3f17016 100644 --- a/quic/conn_id_test.go +++ b/quic/conn_id_test.go @@ -664,3 +664,52 @@ func TestConnIDsCleanedUpAfterClose(t *testing.T) { } }) } + +func TestConnIDRetiredConnIDResent(t *testing.T) { + tc := newTestConn(t, serverSide) + tc.handshake() + tc.ignoreFrame(frameTypeAck) + //tc.ignoreFrame(frameTypeRetireConnectionID) + + // Send CID 2, retire 0-1 (negotiated during the handshake). + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: 2, + retirePriorTo: 2, + connID: testPeerConnID(2), + token: testPeerStatelessResetToken(2), + }) + tc.wantFrame("retire CID 0", packetType1RTT, debugFrameRetireConnectionID{seq: 0}) + tc.wantFrame("retire CID 1", packetType1RTT, debugFrameRetireConnectionID{seq: 1}) + + // Send CID 3, retire 2. + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: 3, + retirePriorTo: 3, + connID: testPeerConnID(3), + token: testPeerStatelessResetToken(3), + }) + tc.wantFrame("retire CID 2", packetType1RTT, debugFrameRetireConnectionID{seq: 2}) + + // Acknowledge retirement of CIDs 0-2. + // The server should have state for only one CID: 3. + tc.writeAckForAll() + if got, want := len(tc.conn.connIDState.remote), 1; got != want { + t.Fatalf("connection has state for %v connection IDs, want %v", got, want) + } + + // Send CID 2 again. + // The server should ignore this, since it's already retired the CID. + tc.ignoreFrames[frameTypeRetireConnectionID] = false + tc.writeFrames(packetType1RTT, + debugFrameNewConnectionID{ + seq: 2, + connID: testPeerConnID(2), + token: testPeerStatelessResetToken(2), + }) + if got, want := len(tc.conn.connIDState.remote), 1; got != want { + t.Fatalf("connection has state for %v connection IDs, want %v", got, want) + } + tc.wantIdle("server does not re-retire already retired CID 2") +} diff --git a/quic/rangeset.go b/quic/rangeset.go index b8b2e9367..528d53df3 100644 --- a/quic/rangeset.go +++ b/quic/rangeset.go @@ -159,6 +159,14 @@ func (s rangeset[T]) numRanges() int { return len(s) } +// size returns the size of all ranges in the rangeset. +func (s rangeset[T]) size() (total T) { + for _, r := range s { + total += r.size() + } + return total +} + // isrange reports if the rangeset covers exactly the range [start, end). func (s rangeset[T]) isrange(start, end T) bool { switch len(s) {