diff --git a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java index a1a4b4be282..bd8825474fa 100644 --- a/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java +++ b/util/src/main/java/io/grpc/util/OutlierDetectionLoadBalancer.java @@ -394,47 +394,55 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { Subchannel subchannel = pickResult.getSubchannel(); if (subchannel != null) { - return PickResult.withSubchannel(subchannel, - new ResultCountingClientStreamTracerFactory( - subchannel.getAttributes().get(ADDRESS_TRACKER_ATTR_KEY))); + return PickResult.withSubchannel(subchannel, new ResultCountingClientStreamTracerFactory( + subchannel.getAttributes().get(ADDRESS_TRACKER_ATTR_KEY), + pickResult.getStreamTracerFactory())); } return pickResult; } /** - * Builds instances of {@link ResultCountingClientStreamTracer}. + * Builds instances of a {@link ClientStreamTracer} that increments the call count in the + * tracker for each closed stream. */ class ResultCountingClientStreamTracerFactory extends ClientStreamTracer.Factory { private final AddressTracker tracker; - ResultCountingClientStreamTracerFactory(AddressTracker tracker) { - this.tracker = tracker; - } - - @Override - public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { - return new ResultCountingClientStreamTracer(tracker); - } - } + @Nullable + private final ClientStreamTracer.Factory delegateFactory; - /** - * Counts the results (successful/unsuccessful) of a particular {@link - * OutlierDetectionSubchannel}s streams and increments the counter in the associated {@link - * AddressTracker}. - */ - class ResultCountingClientStreamTracer extends ClientStreamTracer { - - AddressTracker tracker; - - public ResultCountingClientStreamTracer(AddressTracker tracker) { + ResultCountingClientStreamTracerFactory(AddressTracker tracker, + @Nullable ClientStreamTracer.Factory delegateFactory) { this.tracker = tracker; + this.delegateFactory = delegateFactory; } @Override - public void streamClosed(Status status) { - tracker.incrementCallCount(status.isOk()); + public ClientStreamTracer newClientStreamTracer(StreamInfo info, Metadata headers) { + if (delegateFactory != null) { + ClientStreamTracer delegateTracer = delegateFactory.newClientStreamTracer(info, headers); + return new ForwardingClientStreamTracer() { + @Override + protected ClientStreamTracer delegate() { + return delegateTracer; + } + + @Override + public void streamClosed(Status status) { + tracker.incrementCallCount(status.isOk()); + delegate().streamClosed(status); + } + }; + } else { + return new ClientStreamTracer() { + @Override + public void streamClosed(Status status) { + tracker.incrementCallCount(status.isOk()); + } + }; + } } } } diff --git a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java index 18f9bbf549f..13f13421a1e 100644 --- a/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java +++ b/util/src/test/java/io/grpc/util/OutlierDetectionLoadBalancerTest.java @@ -24,6 +24,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; import com.google.common.collect.ImmutableList; @@ -46,6 +47,7 @@ import io.grpc.LoadBalancer.SubchannelPicker; import io.grpc.LoadBalancer.SubchannelStateListener; import io.grpc.LoadBalancerProvider; +import io.grpc.Metadata; import io.grpc.Status; import io.grpc.SynchronizationContext; import io.grpc.internal.FakeClock; @@ -96,6 +98,10 @@ public class OutlierDetectionLoadBalancerTest { private Helper mockHelper; @Mock private SocketAddress mockSocketAddress; + @Mock + private ClientStreamTracer.Factory mockStreamTracerFactory; + @Mock + private ClientStreamTracer mockStreamTracer; @Captor private ArgumentCaptor connectivityStateCaptor; @@ -193,6 +199,9 @@ public Void answer(InvocationOnMock invocation) throws Throwable { } }); + when(mockStreamTracerFactory.newClientStreamTracer(any(), + any())).thenReturn(mockStreamTracer); + loadBalancer = new OutlierDetectionLoadBalancer(mockHelper, fakeClock.getTimeProvider()); } @@ -355,6 +364,72 @@ public void delegatePick() throws Exception { readySubchannel); } + /** + * Any ClientStreamTracer.Factory set by the delegate picker should still get used. + */ + @Test + public void delegatePickTracerFactoryPreserved() throws Exception { + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) + .setChildPolicy(new PolicySelection(fakeLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers.get(0))); + + // Make one of the subchannels READY. + final Subchannel readySubchannel = subchannels.values().iterator().next(); + deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); + + verify(mockHelper, times(2)).updateBalancingState(stateCaptor.capture(), + pickerCaptor.capture()); + + // Make sure that we can pick the single READY subchannel. + SubchannelPicker picker = pickerCaptor.getAllValues().get(1); + PickResult pickResult = picker.pickSubchannel(mock(PickSubchannelArgs.class)); + + // Calls to a stream tracer created with the factory in the result should make it to a stream + // tracer the underlying LB/picker is using. + ClientStreamTracer clientStreamTracer = pickResult.getStreamTracerFactory() + .newClientStreamTracer(ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); + clientStreamTracer.inboundHeaders(); + // The underlying fake LB provider is configured with a factory that returns a mock stream + // tracer. + verify(mockStreamTracer).inboundHeaders(); + } + + /** + * Assure the tracer works even when the underlying LB does not have a tracer to delegate to. + */ + @Test + public void delegatePickTracerFactoryNotSet() throws Exception { + // We set the mock factory to null to indicate that the delegate does not have its own tracer. + mockStreamTracerFactory = null; + + OutlierDetectionLoadBalancerConfig config = new OutlierDetectionLoadBalancerConfig.Builder() + .setSuccessRateEjection(new SuccessRateEjection.Builder().build()) + .setChildPolicy(new PolicySelection(fakeLbProvider, null)).build(); + + loadBalancer.acceptResolvedAddresses(buildResolvedAddress(config, servers.get(0))); + + // Make one of the subchannels READY. + final Subchannel readySubchannel = subchannels.values().iterator().next(); + deliverSubchannelState(readySubchannel, ConnectivityStateInfo.forNonError(READY)); + + verify(mockHelper, times(2)).updateBalancingState(stateCaptor.capture(), + pickerCaptor.capture()); + + // Make sure that we can pick the single READY subchannel. + SubchannelPicker picker = pickerCaptor.getAllValues().get(1); + PickResult pickResult = picker.pickSubchannel(mock(PickSubchannelArgs.class)); + + // With no delegate tracers factory a call to the OD tracer should still work + ClientStreamTracer clientStreamTracer = pickResult.getStreamTracerFactory() + .newClientStreamTracer(ClientStreamTracer.StreamInfo.newBuilder().build(), new Metadata()); + clientStreamTracer.inboundHeaders(); + + // Sanity check to make sure the delegate tracer does not get called. + verifyNoInteractions(mockStreamTracer); + } + /** * The success rate algorithm leaves a healthy set of addresses alone. */ @@ -1121,7 +1196,7 @@ void assertEjectedSubchannels(Set addresses) { } /** Round robin like fake load balancer. */ - private static final class FakeLoadBalancer extends LoadBalancer { + private final class FakeLoadBalancer extends LoadBalancer { private final Helper helper; List subchannelList; @@ -1159,7 +1234,8 @@ public PickResult pickSubchannel(PickSubchannelArgs args) { if (lastPickIndex < 0 || lastPickIndex > subchannelList.size() - 1) { lastPickIndex = 0; } - return PickResult.withSubchannel(subchannelList.get(lastPickIndex++)); + return PickResult.withSubchannel(subchannelList.get(lastPickIndex++), + mockStreamTracerFactory); } }; helper.updateBalancingState(state, picker);