Skip to content

Commit

Permalink
Expose correct Context.current() in reactive-netty callbacks (#2850)
Browse files Browse the repository at this point in the history
  • Loading branch information
Mateusz Rzeszutek authored Apr 29, 2021
1 parent 0c7a20d commit 56d7fd3
Show file tree
Hide file tree
Showing 8 changed files with 633 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.util.Attribute;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.netty.v4_1.AttributeKeys;
Expand All @@ -35,14 +36,19 @@ public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise prm) {
}

Context context = tracer().startSpan(parentContext, ctx, (HttpRequest) msg);
ctx.channel().attr(AttributeKeys.CLIENT_CONTEXT).set(context);
ctx.channel().attr(AttributeKeys.CLIENT_PARENT_CONTEXT).set(parentContext);

Attribute<Context> clientContextAttr = ctx.channel().attr(AttributeKeys.CLIENT_CONTEXT);
Attribute<Context> parentContextAttr = ctx.channel().attr(AttributeKeys.CLIENT_PARENT_CONTEXT);
clientContextAttr.set(context);
parentContextAttr.set(parentContext);

try (Scope ignored = context.makeCurrent()) {
ctx.write(msg, prm);
// span is ended normally in HttpClientResponseTracingHandler
} catch (Throwable throwable) {
tracer().endExceptionally(context, throwable);
clientContextAttr.remove();
parentContextAttr.remove();
throw throwable;
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,26 +25,32 @@ public class HttpClientResponseTracingHandler extends ChannelInboundHandlerAdapt

@Override
public void channelRead(ChannelHandlerContext ctx, Object msg) {
Context context = ctx.channel().attr(AttributeKeys.CLIENT_CONTEXT).get();
Attribute<Context> clientContextAttr = ctx.channel().attr(AttributeKeys.CLIENT_CONTEXT);
Context context = clientContextAttr.get();
if (context == null) {
ctx.fireChannelRead(msg);
return;
}

Attribute<Context> parentContextAttr = ctx.channel().attr(AttributeKeys.CLIENT_PARENT_CONTEXT);
Context parentContext = parentContextAttr.get();

if (msg instanceof FullHttpResponse) {
tracer().end(context, (HttpResponse) msg);
clientContextAttr.remove();
parentContextAttr.remove();
} else if (msg instanceof HttpResponse) {
// Headers before body have been received, store them to use when finishing the span.
ctx.channel().attr(HTTP_RESPONSE).set((HttpResponse) msg);
} else if (msg instanceof LastHttpContent) {
// Not a FullHttpResponse so this is content that has been received after headers. Finish the
// span using what we stored in attrs.
tracer().end(context, ctx.channel().attr(HTTP_RESPONSE).get());
tracer().end(context, ctx.channel().attr(HTTP_RESPONSE).getAndRemove());
clientContextAttr.remove();
parentContextAttr.remove();
}

// We want the callback in the scope of the parent, not the client span
Attribute<Context> parentAttr = ctx.channel().attr(AttributeKeys.CLIENT_PARENT_CONTEXT);
Context parentContext = parentAttr.get();
if (parentContext != null) {
try (Scope ignored = parentContext.makeCurrent()) {
ctx.fireChannelRead(msg);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Copyright The OpenTelemetry Authors
* SPDX-License-Identifier: Apache-2.0
*/

package io.opentelemetry.javaagent.instrumentation.reactornetty.v0_9;

import io.netty.channel.Channel;
import io.opentelemetry.context.Context;
import io.opentelemetry.context.Scope;
import io.opentelemetry.instrumentation.netty.v4_1.AttributeKeys;
import java.util.function.BiConsumer;
import org.checkerframework.checker.nullness.qual.Nullable;
import reactor.netty.Connection;
import reactor.netty.http.client.HttpClientRequest;
import reactor.netty.http.client.HttpClientResponse;

public final class DecoratorFunctions {

// ignore our own callbacks - or already decorated functions
public static boolean shouldDecorate(Class<?> callbackClass) {
return !callbackClass.getName().startsWith("io.opentelemetry.javaagent");
}

private abstract static class OnMessageDecorator<M> implements BiConsumer<M, Connection> {
private final BiConsumer<? super M, ? super Connection> delegate;
private final boolean forceParentContext;

public OnMessageDecorator(
BiConsumer<? super M, ? super Connection> delegate, boolean forceParentContext) {
this.delegate = delegate;
this.forceParentContext = forceParentContext;
}

@Override
public final void accept(M message, Connection connection) {
Channel channel = connection.channel();
// don't try to get the client span from the netty channel when forceParentSpan is true
// this way the parent context will always be propagated
if (forceParentContext) {
channel = null;
}
Context context = getChannelContext(currentContext(message), channel);
if (context == null) {
delegate.accept(message, connection);
} else {
try (Scope ignored = context.makeCurrent()) {
delegate.accept(message, connection);
}
}
}

abstract reactor.util.context.Context currentContext(M message);
}

public static final class OnRequestDecorator extends OnMessageDecorator<HttpClientRequest> {
public OnRequestDecorator(BiConsumer<? super HttpClientRequest, ? super Connection> delegate) {
super(delegate, false);
}

@Override
reactor.util.context.Context currentContext(HttpClientRequest message) {
return message.currentContext();
}
}

public static final class OnResponseDecorator extends OnMessageDecorator<HttpClientResponse> {
public OnResponseDecorator(
BiConsumer<? super HttpClientResponse, ? super Connection> delegate,
boolean forceParentContext) {
super(delegate, forceParentContext);
}

@Override
reactor.util.context.Context currentContext(HttpClientResponse message) {
return message.currentContext();
}
}

private abstract static class OnMessageErrorDecorator<M> implements BiConsumer<M, Throwable> {
private final BiConsumer<? super M, ? super Throwable> delegate;

public OnMessageErrorDecorator(BiConsumer<? super M, ? super Throwable> delegate) {
this.delegate = delegate;
}

@Override
public final void accept(M message, Throwable throwable) {
Context context = getChannelContext(currentContext(message), null);
if (context == null) {
delegate.accept(message, throwable);
} else {
try (Scope ignored = context.makeCurrent()) {
delegate.accept(message, throwable);
}
}
}

abstract reactor.util.context.Context currentContext(M message);
}

public static final class OnRequestErrorDecorator
extends OnMessageErrorDecorator<HttpClientRequest> {
public OnRequestErrorDecorator(
BiConsumer<? super HttpClientRequest, ? super Throwable> delegate) {
super(delegate);
}

@Override
reactor.util.context.Context currentContext(HttpClientRequest message) {
return message.currentContext();
}
}

public static final class OnResponseErrorDecorator
extends OnMessageErrorDecorator<HttpClientResponse> {
public OnResponseErrorDecorator(
BiConsumer<? super HttpClientResponse, ? super Throwable> delegate) {
super(delegate);
}

@Override
reactor.util.context.Context currentContext(HttpClientResponse message) {
return message.currentContext();
}
}

@Nullable
private static Context getChannelContext(
reactor.util.context.Context reactorContext, @Nullable Channel channel) {
// try to get the client span context from the channel if it's available
if (channel != null) {
Context context = channel.attr(AttributeKeys.CLIENT_CONTEXT).get();
if (context != null) {
return context;
}
}
// otherwise use the parent span context
return reactorContext.getOrDefault(
ReactorNettyInstrumentationModule.MapConnect.CONTEXT_ATTRIBUTE, null);
}

private DecoratorFunctions() {}
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,12 @@

import static io.opentelemetry.javaagent.tooling.bytebuddy.matcher.ClassLoaderMatcher.hasClassesNamed;
import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static net.bytebuddy.matcher.ElementMatchers.isPublic;
import static net.bytebuddy.matcher.ElementMatchers.isStatic;
import static net.bytebuddy.matcher.ElementMatchers.named;
import static net.bytebuddy.matcher.ElementMatchers.namedOneOf;
import static net.bytebuddy.matcher.ElementMatchers.takesArgument;
import static net.bytebuddy.matcher.ElementMatchers.takesArguments;

import com.google.auto.service.AutoService;
import io.netty.bootstrap.Bootstrap;
Expand All @@ -19,6 +21,7 @@
import io.opentelemetry.javaagent.instrumentation.api.CallDepthThreadLocalMap;
import io.opentelemetry.javaagent.tooling.InstrumentationModule;
import io.opentelemetry.javaagent.tooling.TypeInstrumentation;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.BiConsumer;
Expand All @@ -31,6 +34,7 @@
import reactor.netty.Connection;
import reactor.netty.http.client.HttpClient;
import reactor.netty.http.client.HttpClientRequest;
import reactor.netty.http.client.HttpClientResponse;

/**
* This instrumentation solves the problem of the correct context propagation through the roller
Expand Down Expand Up @@ -64,9 +68,44 @@ public ElementMatcher<TypeDescription> typeMatcher() {

@Override
public Map<? extends ElementMatcher<? super MethodDescription>, String> transformers() {
return singletonMap(
Map<ElementMatcher.Junction<MethodDescription>, String> transformers = new HashMap<>();
transformers.put(
isStatic().and(namedOneOf("create", "newConnection", "from")),
ReactorNettyInstrumentationModule.class.getName() + "$CreateAdvice");

// advice classes below expose current context in doOn*/doAfter* callbacks
transformers.put(
isPublic()
.and(namedOneOf("doOnRequest", "doAfterRequest"))
.and(takesArguments(1))
.and(takesArgument(0, BiConsumer.class)),
ReactorNettyInstrumentationModule.class.getName() + "$OnRequestAdvice");
transformers.put(
isPublic()
.and(named("doOnRequestError"))
.and(takesArguments(1))
.and(takesArgument(0, BiConsumer.class)),
ReactorNettyInstrumentationModule.class.getName() + "$OnRequestErrorAdvice");
transformers.put(
isPublic()
.and(namedOneOf("doOnResponse", "doAfterResponse"))
.and(takesArguments(1))
.and(takesArgument(0, BiConsumer.class)),
ReactorNettyInstrumentationModule.class.getName() + "$OnResponseAdvice");
transformers.put(
isPublic()
.and(named("doOnResponseError"))
.and(takesArguments(1))
.and(takesArgument(0, BiConsumer.class)),
ReactorNettyInstrumentationModule.class.getName() + "$OnResponseErrorAdvice");
transformers.put(
isPublic()
.and(named("doOnError"))
.and(takesArguments(2))
.and(takesArgument(0, BiConsumer.class))
.and(takesArgument(1, BiConsumer.class)),
ReactorNettyInstrumentationModule.class.getName() + "$OnErrorAdvice");
return transformers;
}
}

Expand Down Expand Up @@ -105,4 +144,66 @@ public void accept(HttpClientRequest r, Connection c) {
c.channel().attr(AttributeKeys.WRITE_CONTEXT).set(context);
}
}

public static class OnRequestAdvice {
@Advice.OnMethodEnter(suppress = Throwable.class)
public static void onEnter(
@Advice.Argument(value = 0, readOnly = false)
BiConsumer<? super HttpClientRequest, ? super Connection> callback) {
if (DecoratorFunctions.shouldDecorate(callback.getClass())) {
callback = new DecoratorFunctions.OnRequestDecorator(callback);
}
}
}

public static class OnRequestErrorAdvice {
@Advice.OnMethodEnter(suppress = Throwable.class)
public static void onEnter(
@Advice.Argument(value = 0, readOnly = false)
BiConsumer<? super HttpClientRequest, ? super Throwable> callback) {
if (DecoratorFunctions.shouldDecorate(callback.getClass())) {
callback = new DecoratorFunctions.OnRequestErrorDecorator(callback);
}
}
}

public static class OnResponseAdvice {
@Advice.OnMethodEnter(suppress = Throwable.class)
public static void onEnter(
@Advice.Argument(value = 0, readOnly = false)
BiConsumer<? super HttpClientResponse, ? super Connection> callback,
@Advice.Origin("#m") String methodName) {
if (DecoratorFunctions.shouldDecorate(callback.getClass())) {
boolean forceParentContext = methodName.equals("doAfterResponse");
callback = new DecoratorFunctions.OnResponseDecorator(callback, forceParentContext);
}
}
}

public static class OnResponseErrorAdvice {
@Advice.OnMethodEnter(suppress = Throwable.class)
public static void onEnter(
@Advice.Argument(value = 0, readOnly = false)
BiConsumer<? super HttpClientResponse, ? super Throwable> callback) {
if (DecoratorFunctions.shouldDecorate(callback.getClass())) {
callback = new DecoratorFunctions.OnResponseErrorDecorator(callback);
}
}
}

public static class OnErrorAdvice {
@Advice.OnMethodEnter(suppress = Throwable.class)
public static void onEnter(
@Advice.Argument(value = 0, readOnly = false)
BiConsumer<? super HttpClientRequest, ? super Throwable> requestCallback,
@Advice.Argument(value = 1, readOnly = false)
BiConsumer<? super HttpClientResponse, ? super Throwable> responseCallback) {
if (DecoratorFunctions.shouldDecorate(requestCallback.getClass())) {
requestCallback = new DecoratorFunctions.OnRequestErrorDecorator(requestCallback);
}
if (DecoratorFunctions.shouldDecorate(responseCallback.getClass())) {
responseCallback = new DecoratorFunctions.OnResponseErrorDecorator(responseCallback);
}
}
}
}
Loading

0 comments on commit 56d7fd3

Please sign in to comment.