Skip to content

Commit

Permalink
feat: Decouples Diagnostics tracing from Google's trace header
Browse files Browse the repository at this point in the history
Towards #5360 and #5897
  • Loading branch information
amanda-tarafa committed May 27, 2021
1 parent 553ab40 commit b35b9ea
Show file tree
Hide file tree
Showing 19 changed files with 851 additions and 223 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,8 @@ public static IWebHostBuilder GetHostBuilder<TStartup>(Action<IWebHostBuilder> c

public static TestServer GetTestServer(IWebHostBuilder hostBuilder) => new TestServer(hostBuilder);

public static TestServer GetTestServer<TStartup>() where TStartup : class =>
GetTestServer(GetHostBuilder<TStartup>());
public static TestServer GetTestServer<TStartup>(Action<IWebHostBuilder> configure = null) where TStartup : class =>
GetTestServer(GetHostBuilder<TStartup>(configure));

public static IServiceProvider GetServices(TestServer server) => server.Host.Services;

Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
using Google.Cloud.Diagnostics.Common.IntegrationTests;
using Google.Protobuf.WellKnownTypes;
using Microsoft.AspNetCore.Builder;
using Microsoft.AspNetCore.Http;
using Microsoft.AspNetCore.Mvc;
using Microsoft.AspNetCore.TestHost;
using Microsoft.Extensions.DependencyInjection;
Expand Down Expand Up @@ -151,6 +152,31 @@ public async Task Traces_OutgoingClientFactory()
Assert.False(response.Headers.Contains(TraceHeaderContext.TraceHeader));
}

[Fact]
public async Task Traces_CustomTraceContext()
{
var uri = $"/TraceSamples/{nameof(TraceSamplesController.TraceHelloWorld)}/{_testId}";
var traceId = TraceIdFactory.Create().NextId();

using var server = GetTestServer<CustomTraceContextTestApplication.Startup>();
using var client = server.CreateClient();
var request = new HttpRequestMessage(HttpMethod.Get, uri)
{
Headers = { { "custom_trace_id", traceId } }
};
var response = await client.SendAsync(request);

var trace = s_polling.GetTrace(uri, _startTime);

TraceEntryVerifiers.AssertParentChildSpan(trace, uri, _testId);
TraceEntryVerifiers.AssertSpanLabelsContains(
trace.Spans.First(s => s.Name == uri), TraceEntryData.HttpGetSuccessLabels);
Assert.False(response.Headers.Contains(TraceHeaderContext.TraceHeader));

Assert.Equal(traceId, trace.TraceId);
Assert.True(response.Headers.Contains("custom_trace_id"));
}

private void Troubleshooting()
{
// Not a test - just a sample.
Expand Down Expand Up @@ -201,9 +227,6 @@ public override void Configure(IApplicationBuilder app, ILoggerFactory loggerFac
// Sample: RegisterGoogleTracer
public void ConfigureServices(IServiceCollection services)
{
// The line below is needed for trace ids to be added to logs.
services.AddHttpContextAccessor();

// Replace ProjectId with your Google Cloud Project ID.
services.AddGoogleTrace(options =>
{
Expand Down Expand Up @@ -273,15 +296,111 @@ public override void Configure(IApplicationBuilder app, ILoggerFactory loggerFac
// Sample: ConfigureHttpClient
public void ConfigureServices(IServiceCollection services)
{
// The line below is needed for trace ids to be added to logs.
services.AddHttpContextAccessor();
// Replace ProjectId with your Google Cloud Project ID.
services.AddGoogleTrace(options =>
{
options.ProjectId = ProjectId;
});

// Register an HttpClient for outgoing requests.
services.AddHttpClient("tracesOutgoing")
// The next call guarantees that trace information is propagated for outgoing
// requests that are already being traced.
.AddOutgoingGoogleTraceHandler();

// Add any other services your application requires, for instance,
// depending on the version of ASP.NET Core you are using, you may
// need one of the following:

// services.AddMvc();

// services.AddControllersWithViews();
}

public void Configure(IApplicationBuilder app)
{
// Use at the start of the request pipeline to ensure the entire request is traced.
app.UseGoogleTrace();

// Add any other configuration your application requires, for instance,
// depending on the verson of ASP.NET Core you are using, you may
// need one of the following:

//app.UseMvc(routes =>
//{
// routes.MapRoute(
// name: "default",
// template: "{controller=Home}/{action=Index}/{id?}");
//});

//app.UseRouting();
//app.UseEndpoints(endpoints =>
//{
// endpoints.MapControllerRoute(
// name: "default",
// pattern: "{controller=Home}/{action=Index}/{id?}");
// endpoints.MapRazorPages();
//});
}
// End sample
}

internal class CustomTraceContextTestApplication
{
private static readonly string ProjectId = TestEnvironment.GetTestProjectId();

// To hide some implementation details from the
// sample code, like how we are overriding the methods.
internal class Startup : BaseStartup
{
private readonly CustomTraceContextTestApplication application = new CustomTraceContextTestApplication();

public override void ConfigureServices(IServiceCollection services)
{
application.ConfigureServices(services);
base.ConfigureServices(services);
}

public override void Configure(IApplicationBuilder app, ILoggerFactory loggerFactory)
{
application.Configure(app);
base.Configure(app, loggerFactory);
}
}

// Sample: CustomTraceContext
public void ConfigureServices(IServiceCollection services)
{
// Register a trace context provider method that inspects the request and
// extracts the trace context information.
services.AddScoped(CustomTraceContextProvider);
static ITraceContext CustomTraceContextProvider(IServiceProvider sp)
{
var accessor = sp.GetRequiredService<IHttpContextAccessor>();
string traceId = accessor.HttpContext?.Request?.Headers["custom_trace_id"];
return new SimpleTraceContext(traceId, null, null);
}

// Register a method that sets the updated trace context information on the response.
services.AddSingleton<Action<HttpResponse, ITraceContext>>(
(response, traceContext) => response.Headers.Add("custom_trace_id", traceContext.TraceId));

// Now you can register Google Trace normally.

// Replace ProjectId with your Google Cloud Project ID.
services.AddGoogleTrace(options =>
{
options.ProjectId = ProjectId;
});

// If your application is making outgoing HTTP requests then you configure
// your HTTP client for trace propagation as you normally would.
// If the trace context information should be propagated in a custom format
// then you register a method that sets the trace context information on the
// outgoing request.
services.AddSingleton<Action<HttpRequestMessage, ITraceContext>>(
(request, traceContext) => request.Headers.Add("custom_trace_id", traceContext.TraceId));

// Register an HttpClient for outgoing requests.
services.AddHttpClient("tracesOutgoing")
// The next call guarantees that trace information is propagated for outgoing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ private IServiceProvider CreateProviderForTraceHeaderContext(string traceHeader)
var request = context.Request;
request.Headers[TraceHeaderContext.TraceHeader] = traceHeader;

var accessor = new HttpContextAccessor();
accessor.HttpContext = context;
var accessor = new HttpContextAccessor { HttpContext = context };

var traceIdFactory = TraceIdFactory.Create();

Expand All @@ -63,16 +62,16 @@ public void CreateTraceHeaderContext()
{
var header = $"{_traceId}/{_spanId};o=1";
var provider = CreateProviderForTraceHeaderContext(header);
var headerContext = CloudTraceExtension.CreateTraceHeaderContext(provider);
var headerContext = CloudTraceExtension.ProvideGoogleTraceHeaderContext(provider);
Assert.Equal(TraceHeaderContext.FromHeader(header).ToString(), headerContext.ToString());
}

[Fact]
public void CreateTraceHeaderContext_UseBackUpFunc()
public void CreateTraceHeaderContext_UseShouldTraceFallback()
{
var header = $"{_traceId}/{_spanId};";
var provider = CreateProviderForTraceHeaderContext(header);
var headerContext = CloudTraceExtension.CreateTraceHeaderContext(provider);
var headerContext = CloudTraceExtension.ProvideGoogleTraceHeaderContext(provider);
Assert.Equal(TraceHeaderContext.FromHeader(header).ToString(), headerContext.ToString());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ namespace Google.Cloud.Diagnostics.AspNetCore.Tests
public class CloudTraceMiddlewareTest
{
private static readonly TraceIdFactory _traceIdFactory = TraceIdFactory.Create();
private static readonly TraceHeaderContext _traceHeaderContext =
new TraceHeaderContext(_traceIdFactory.NextId(), 0, true);
private static readonly ITraceContext _traceContext = new SimpleTraceContext(_traceIdFactory.NextId(), 0, true);

/// <summary>
/// Creates a <see cref="Mock{IManagedTracer}"/> that is set up to start and end a span as well as
Expand All @@ -42,7 +41,7 @@ public class CloudTraceMiddlewareTest
private static Mock<IManagedTracer> CreateIManagedTracerMock(HttpContext context)
{
var tracerMock = new Mock<IManagedTracer>();
tracerMock.Setup(t => t.GetCurrentTraceId()).Returns(_traceHeaderContext.TraceId);
tracerMock.Setup(t => t.GetCurrentTraceId()).Returns(_traceContext.TraceId);
tracerMock.Setup(t => t.StartSpan(context.Request.Path, null)).Returns(new NullManagedTracer.Span());
tracerMock.Setup(t => t.AnnotateSpan(It.IsAny<Dictionary<string, string>>()));
return tracerMock;
Expand All @@ -59,6 +58,24 @@ private static HttpContext CreateHttpContext()
return context;
}

private static void CustomTraceContextPropagator(HttpResponse response, ITraceContext context) =>
response.Headers.Add("custom_trace", context.TraceId);

private static void AssertCustomTraceContext(HttpResponse response)
{
// Let's make sure that we don't add the Google trace (unless the propagator does).
Assert.False(response.Headers.ContainsKey(TraceHeaderContext.TraceHeader));
Assert.True(response.Headers.ContainsKey("custom_trace"));
Assert.Equal(_traceContext.TraceId, response.Headers["custom_trace"]);
}

private static void AssertNoTraceContext(HttpResponse response)
{
// Let's make sure that we don't add the Google trace (unless the propagator does).
Assert.False(response.Headers.ContainsKey(TraceHeaderContext.TraceHeader));
Assert.False(response.Headers.ContainsKey("custom_trace"));
}

[Fact]
public async Task Invoke_Trace()
{
Expand All @@ -68,18 +85,17 @@ public async Task Invoke_Trace()
var delegateMock = new Mock<RequestDelegate>();
delegateMock.Setup(d => d(context)).Returns(Task.CompletedTask);

Func<TraceHeaderContext, IManagedTracer> fakeFactory = f => tracerMock.Object;
Func<ITraceContext, IManagedTracer> fakeFactory = f => tracerMock.Object;

Assert.Equal(NullManagedTracer.Instance, ContextTracerManager.GetCurrentTracer());

var middleware = new CloudTraceMiddleware(delegateMock.Object, fakeFactory, new DefaultCloudTraceNameProvider());
await middleware.Invoke(context, _traceHeaderContext);
await middleware.Invoke(context, _traceContext, TraceDecisionPredicate.Default, CustomTraceContextPropagator);

// Since the current tracer is AsyncLocal<>, it will be back to the default after awaiting the middleware invoke
Assert.Equal(NullManagedTracer.Instance, ContextTracerManager.GetCurrentTracer());

Assert.True(context.Response.Headers.ContainsKey(TraceHeaderContext.TraceHeader));
Assert.Equal(_traceHeaderContext.ToString(), context.Response.Headers[TraceHeaderContext.TraceHeader]);
AssertCustomTraceContext(context.Response);

delegateMock.VerifyAll();
tracerMock.VerifyAll();
Expand All @@ -95,20 +111,18 @@ public async Task Invoke_TraceException()
var delegateMock = new Mock<RequestDelegate>();
delegateMock.Setup(d => d(context)).Throws(new DivideByZeroException());

Func<TraceHeaderContext, IManagedTracer> fakeFactory = f => tracerMock.Object;
Func<ITraceContext, IManagedTracer> fakeFactory = f => tracerMock.Object;

var middleware = new CloudTraceMiddleware(delegateMock.Object, fakeFactory, new DefaultCloudTraceNameProvider());
await Assert.ThrowsAsync<DivideByZeroException>(
() => middleware.Invoke(context, _traceHeaderContext));
() => middleware.Invoke(context, _traceContext, TraceDecisionPredicate.Default, CustomTraceContextPropagator));

Assert.True(context.Response.Headers.ContainsKey(TraceHeaderContext.TraceHeader));
Assert.Equal(_traceHeaderContext.ToString(), context.Response.Headers[TraceHeaderContext.TraceHeader]);
AssertCustomTraceContext(context.Response);

delegateMock.VerifyAll();
tracerMock.VerifyAll();
}


[Fact]
public async Task Invoke_TraceThrowsAndException()
{
Expand All @@ -120,14 +134,13 @@ public async Task Invoke_TraceThrowsAndException()
var delegateMock = new Mock<RequestDelegate>();
delegateMock.Setup(d => d(context)).Throws(new DivideByZeroException());

Func<TraceHeaderContext, IManagedTracer> fakeFactory = f => tracerMock.Object;
Func<ITraceContext, IManagedTracer> fakeFactory = f => tracerMock.Object;

var middleware = new CloudTraceMiddleware(delegateMock.Object, fakeFactory, new DefaultCloudTraceNameProvider());
await Assert.ThrowsAsync<AggregateException>(
() => middleware.Invoke(context, _traceHeaderContext));
() => middleware.Invoke(context, _traceContext, TraceDecisionPredicate.Default, CustomTraceContextPropagator));

Assert.True(context.Response.Headers.ContainsKey(TraceHeaderContext.TraceHeader));
Assert.Equal(_traceHeaderContext.ToString(), context.Response.Headers[TraceHeaderContext.TraceHeader]);
AssertCustomTraceContext(context.Response);

delegateMock.VerifyAll();
tracerMock.VerifyAll();
Expand All @@ -140,15 +153,15 @@ public async Task Invoke_NoTrace()
var delegateMock = new Mock<RequestDelegate>();
var tracerMock = new Mock<IManagedTracer>();

Func<TraceHeaderContext, IManagedTracer> fakeFactory = f => tracerMock.Object;
Func<ITraceContext, IManagedTracer> fakeFactory = f => tracerMock.Object;

var middleware = new CloudTraceMiddleware(delegateMock.Object, fakeFactory, new DefaultCloudTraceNameProvider());
await middleware.Invoke(context, _traceHeaderContext);
await middleware.Invoke(context, _traceContext, TraceDecisionPredicate.Default, CustomTraceContextPropagator);

// Since the current tracer is AsyncLocal<>, it will be back to the default after awaiting the middleware invoke
Assert.Equal(NullManagedTracer.Instance, ContextTracerManager.GetCurrentTracer());

Assert.False(context.Response.Headers.ContainsKey(TraceHeaderContext.TraceHeader));
AssertNoTraceContext(context.Response);

delegateMock.Verify(d => d(context), Times.Once());
tracerMock.Verify(t => t.StartSpan(It.IsAny<string>(), null), Times.Never());
Expand Down
Loading

0 comments on commit b35b9ea

Please sign in to comment.