Skip to content

Commit

Permalink
pass a context to logging.Tracer.NewConnectionTracer
Browse files Browse the repository at this point in the history
This context has the same value attached to it as the context returned
by Session.Context().
In the case of a dialed connection, this context is derived from the
context used for dialing.
  • Loading branch information
marten-seemann committed Apr 14, 2021
1 parent 4917760 commit e1e38d2
Show file tree
Hide file tree
Showing 10 changed files with 54 additions and 35 deletions.
8 changes: 6 additions & 2 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,13 +203,17 @@ func dialContext(
}
c.packetHandlers = packetHandlers

c.tracingID = nextSessionTracingID()
if c.config.Tracer != nil {
c.tracer = c.config.Tracer.TracerForConnection(protocol.PerspectiveClient, c.destConnID)
c.tracer = c.config.Tracer.TracerForConnection(
context.WithValue(ctx, SessionTracingKey, c.tracingID),
protocol.PerspectiveClient,
c.destConnID,
)
}
if c.tracer != nil {
c.tracer.StartedConnection(c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID)
}
c.tracingID = nextSessionTracingID()
if err := c.dial(ctx); err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ var _ = Describe("Client", func() {
originalClientSessConstructor = newClientSession
tracer = mocklogging.NewMockConnectionTracer(mockCtrl)
tr := mocklogging.NewMockTracer(mockCtrl)
tr.EXPECT().TracerForConnection(protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
tr.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveClient, gomock.Any()).Return(tracer).MaxTimes(1)
config = &Config{Tracer: tr, Versions: []protocol.VersionNumber{protocol.VersionTLS}}
Eventually(areSessionsRunning).Should(BeFalse())
// sess = NewMockQuicSession(mockCtrl)
Expand Down
9 changes: 5 additions & 4 deletions internal/mocks/logging/tracer.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion logging/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package logging

import (
"context"
"net"
"time"

Expand Down Expand Up @@ -95,7 +96,7 @@ type Tracer interface {
// The ODCID is the original destination connection ID:
// The destination connection ID that the client used on the first Initial packet it sent on this connection.
// If nil is returned, tracing will be disabled for this connection.
TracerForConnection(p Perspective, odcid ConnectionID) ConnectionTracer
TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer

SentPacket(net.Addr, *Header, ByteCount, []Frame)
DroppedPacket(net.Addr, PacketType, ByteCount, PacketDropReason)
Expand Down
9 changes: 5 additions & 4 deletions logging/mock_tracer_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 3 additions & 2 deletions logging/multiplex.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package logging

import (
"context"
"net"
"time"
)
Expand All @@ -22,10 +23,10 @@ func NewMultiplexedTracer(tracers ...Tracer) Tracer {
return &tracerMultiplexer{tracers}
}

func (m *tracerMultiplexer) TracerForConnection(p Perspective, odcid ConnectionID) ConnectionTracer {
func (m *tracerMultiplexer) TracerForConnection(ctx context.Context, p Perspective, odcid ConnectionID) ConnectionTracer {
var connTracers []ConnectionTracer
for _, t := range m.tracers {
if ct := t.TracerForConnection(p, odcid); ct != nil {
if ct := t.TracerForConnection(ctx, p, odcid); ct != nil {
connTracers = append(connTracers, ct)
}
}
Expand Down
29 changes: 17 additions & 12 deletions logging/multiplex_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package logging

import (
"context"
"net"
"time"

Expand Down Expand Up @@ -35,35 +36,39 @@ var _ = Describe("Tracing", func() {
})

It("multiplexes the TracerForConnection call", func() {
tr1.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3})
tr2.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3})
tracer.TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3})
ctx := context.Background()
tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})
tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})
tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})
})

It("uses multiple connection tracers", func() {
ctx := context.Background()
ctr1 := NewMockConnectionTracer(mockCtrl)
ctr2 := NewMockConnectionTracer(mockCtrl)
tr1.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1)
tr2.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr2)
tr := tracer.TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3})
tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1)
tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr2)
tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3})
ctr1.EXPECT().LossTimerCanceled()
ctr2.EXPECT().LossTimerCanceled()
tr.LossTimerCanceled()
})

It("handles tracers that return a nil ConnectionTracer", func() {
ctx := context.Background()
ctr1 := NewMockConnectionTracer(mockCtrl)
tr1.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1)
tr2.EXPECT().TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3})
tr := tracer.TracerForConnection(PerspectiveServer, ConnectionID{1, 2, 3})
tr1.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3}).Return(ctr1)
tr2.EXPECT().TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3})
tr := tracer.TracerForConnection(ctx, PerspectiveServer, ConnectionID{1, 2, 3})
ctr1.EXPECT().LossTimerCanceled()
tr.LossTimerCanceled()
})

It("returns nil when all tracers return a nil ConnectionTracer", func() {
tr1.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3})
tr2.EXPECT().TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3})
Expect(tracer.TracerForConnection(PerspectiveClient, ConnectionID{1, 2, 3})).To(BeNil())
ctx := context.Background()
tr1.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})
tr2.EXPECT().TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})
Expect(tracer.TracerForConnection(ctx, PerspectiveClient, ConnectionID{1, 2, 3})).To(BeNil())
})

It("traces the PacketSent event", func() {
Expand Down
3 changes: 2 additions & 1 deletion qlog/qlog.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package qlog

import (
"bytes"
"context"
"fmt"
"io"
"log"
Expand Down Expand Up @@ -59,7 +60,7 @@ func NewTracer(getLogWriter func(p logging.Perspective, connectionID []byte) io.
return &tracer{getLogWriter: getLogWriter}
}

func (t *tracer) TracerForConnection(p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer {
func (t *tracer) TracerForConnection(_ context.Context, p logging.Perspective, odcid protocol.ConnectionID) logging.ConnectionTracer {
if w := t.getLogWriter(p, odcid.Bytes()); w != nil {
return NewConnectionTracer(w, p, odcid)
}
Expand Down
9 changes: 7 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -451,6 +451,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
}
s.logger.Debugf("Changing connection ID to %s.", connID)
var sess quicSession
tracingID := nextSessionTracingID()
if added := s.sessionHandler.AddWithConnID(hdr.DestConnectionID, connID, func() packetHandler {
var tracer logging.ConnectionTracer
if s.config.Tracer != nil {
Expand All @@ -459,7 +460,11 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
if origDestConnID.Len() > 0 {
connID = origDestConnID
}
tracer = s.config.Tracer.TracerForConnection(protocol.PerspectiveServer, connID)
tracer = s.config.Tracer.TracerForConnection(
context.WithValue(context.Background(), SessionTracingKey, tracingID),
protocol.PerspectiveServer,
connID,
)
}
sess = s.newSession(
newSendConn(s.conn, p.remoteAddr, p.info),
Expand All @@ -475,7 +480,7 @@ func (s *baseServer) handleInitialImpl(p *receivedPacket, hdr *wire.Header) erro
s.tokenGenerator,
s.acceptEarlySessions,
tracer,
nextSessionTracingID(),
tracingID,
s.logger,
hdr.Version,
)
Expand Down
12 changes: 6 additions & 6 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ var _ = Describe("Server", func() {
fn()
return true
})
tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{0xde, 0xad, 0xc0, 0xde})
sess := NewMockQuicSession(mockCtrl)
serv.newSession = func(
_ sendConn,
Expand Down Expand Up @@ -579,7 +579,7 @@ var _ = Describe("Server", func() {
fn()
return true
})
tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})

sess := NewMockQuicSession(mockCtrl)
serv.newSession = func(
Expand Down Expand Up @@ -637,7 +637,7 @@ var _ = Describe("Server", func() {
fn()
return true
}).AnyTimes()
tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any()).AnyTimes()
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).AnyTimes()

serv.config.AcceptToken = func(net.Addr, *Token) bool { return true }
acceptSession := make(chan struct{})
Expand Down Expand Up @@ -760,7 +760,7 @@ var _ = Describe("Server", func() {
fn()
return true
}).Times(protocol.MaxAcceptQueueSize)
tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize)
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any()).Times(protocol.MaxAcceptQueueSize)

var wg sync.WaitGroup
wg.Add(protocol.MaxAcceptQueueSize)
Expand Down Expand Up @@ -832,7 +832,7 @@ var _ = Describe("Server", func() {
fn()
return true
})
tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any())
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())

serv.handlePacket(p)
// make sure there are no Write calls on the packet conn
Expand Down Expand Up @@ -940,7 +940,7 @@ var _ = Describe("Server", func() {
fn()
return true
})
tracer.EXPECT().TracerForConnection(protocol.PerspectiveServer, gomock.Any())
tracer.EXPECT().TracerForConnection(gomock.Any(), protocol.PerspectiveServer, gomock.Any())
serv.handleInitialImpl(
&receivedPacket{buffer: getPacketBuffer()},
&wire.Header{DestConnectionID: protocol.ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}},
Expand Down

0 comments on commit e1e38d2

Please sign in to comment.