Skip to content

Commit

Permalink
chore: Clean up resources in gRPCStreamDirectController and ChannelPo…
Browse files Browse the repository at this point in the history
…ol tests (#1691)

* chore: Shutdown test resources

* chore: Clean up resources for the flaky test

* chore: Clean up test

* chore: Shutdown ChannelPool resources

* chore: Use constant for termination seconds

* chore: Clean up exception in runnable

* chore: Increase termination delay

* chore: Clean up resources

* chore: Clean up resources after tests
  • Loading branch information
lqiu96 authored Jun 12, 2023
1 parent dfa9d2b commit 8cbea70
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 82 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import com.google.api.gax.rpc.ServerStreamingCallSettings;
import com.google.api.gax.rpc.ServerStreamingCallable;
import com.google.api.gax.rpc.StreamController;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.type.Color;
Expand All @@ -62,6 +63,7 @@
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import org.junit.After;
import org.junit.Assert;
import org.junit.Test;
import org.junit.runner.RunWith;
Expand All @@ -72,14 +74,24 @@

@RunWith(JUnit4.class)
public class ChannelPoolTest {
private static final int DEFAULT_AWAIT_TERMINATION_SEC = 10;
private ChannelPool pool;

@After
public void cleanup() throws InterruptedException {
Preconditions.checkNotNull(pool, "Channel pool was never created");
pool.shutdown();
pool.awaitTermination(DEFAULT_AWAIT_TERMINATION_SEC, TimeUnit.SECONDS);
}

@Test
public void testAuthority() throws IOException {
ManagedChannel sub1 = Mockito.mock(ManagedChannel.class);
ManagedChannel sub2 = Mockito.mock(ManagedChannel.class);

Mockito.when(sub1.authority()).thenReturn("myAuth");

ChannelPool pool =
pool =
ChannelPool.create(
ChannelPoolSettings.staticallySized(2),
new FakeChannelFactory(Arrays.asList(sub1, sub2)));
Expand All @@ -94,7 +106,7 @@ public void testRoundRobin() throws IOException {
Mockito.when(sub1.authority()).thenReturn("myAuth");

ArrayList<ManagedChannel> channels = Lists.newArrayList(sub1, sub2);
ChannelPool pool =
pool =
ChannelPool.create(
ChannelPoolSettings.staticallySized(channels.size()), new FakeChannelFactory(channels));

Expand Down Expand Up @@ -150,7 +162,7 @@ public void ensureEvenDistribution() throws InterruptedException, IOException {
});
}

final ChannelPool pool =
pool =
ChannelPool.create(
ChannelPoolSettings.staticallySized(numChannels),
new FakeChannelFactory(Arrays.asList(channels)));
Expand Down Expand Up @@ -184,12 +196,13 @@ public void channelPrimerShouldCallPoolConstruction() throws IOException {
ManagedChannel channel1 = Mockito.mock(ManagedChannel.class);
ManagedChannel channel2 = Mockito.mock(ManagedChannel.class);

ChannelPool.create(
ChannelPoolSettings.staticallySized(2)
.toBuilder()
.setPreemptiveRefreshEnabled(true)
.build(),
new FakeChannelFactory(Arrays.asList(channel1, channel2), mockChannelPrimer));
pool =
ChannelPool.create(
ChannelPoolSettings.staticallySized(2)
.toBuilder()
.setPreemptiveRefreshEnabled(true)
.build(),
new FakeChannelFactory(Arrays.asList(channel1, channel2), mockChannelPrimer));
Mockito.verify(mockChannelPrimer, Mockito.times(2))
.primeChannel(Mockito.any(ManagedChannel.class));
}
Expand Down Expand Up @@ -221,13 +234,14 @@ public void channelPrimerIsCalledPeriodically() throws IOException {
FakeChannelFactory channelFactory =
new FakeChannelFactory(Arrays.asList(channel1, channel2, channel3), mockChannelPrimer);

new ChannelPool(
ChannelPoolSettings.staticallySized(1)
.toBuilder()
.setPreemptiveRefreshEnabled(true)
.build(),
channelFactory,
scheduledExecutorService);
pool =
new ChannelPool(
ChannelPoolSettings.staticallySized(1)
.toBuilder()
.setPreemptiveRefreshEnabled(true)
.build(),
channelFactory,
scheduledExecutorService);
// 1 call during the creation
Mockito.verify(mockChannelPrimer, Mockito.times(1))
.primeChannel(Mockito.any(ManagedChannel.class));
Expand All @@ -251,7 +265,7 @@ public void callShouldCompleteAfterCreation() throws IOException {
ManagedChannel replacementChannel = Mockito.mock(ManagedChannel.class);
FakeChannelFactory channelFactory =
new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel));
ChannelPool pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory);
pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory);

// create a mock call when new call comes to the underlying channel
MockClientCall<String, Integer> mockClientCall = new MockClientCall<>(1, Status.OK);
Expand Down Expand Up @@ -300,7 +314,7 @@ public void callShouldCompleteAfterStarted() throws IOException {

FakeChannelFactory channelFactory =
new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel));
ChannelPool pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory);
pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory);

// create a mock call when new call comes to the underlying channel
MockClientCall<String, Integer> mockClientCall = new MockClientCall<>(1, Status.OK);
Expand Down Expand Up @@ -345,7 +359,7 @@ public void channelShouldShutdown() throws IOException {

FakeChannelFactory channelFactory =
new FakeChannelFactory(ImmutableList.of(underlyingChannel, replacementChannel));
ChannelPool pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory);
pool = ChannelPool.create(ChannelPoolSettings.staticallySized(1), channelFactory);

// create a mock call when new call comes to the underlying channel
MockClientCall<String, Integer> mockClientCall = new MockClientCall<>(1, Status.OK);
Expand Down Expand Up @@ -397,7 +411,7 @@ public void channelRefreshShouldSwapChannels() throws IOException {

FakeChannelFactory channelFactory =
new FakeChannelFactory(ImmutableList.of(underlyingChannel1, underlyingChannel2));
ChannelPool pool =
pool =
new ChannelPool(
ChannelPoolSettings.staticallySized(1)
.toBuilder()
Expand Down Expand Up @@ -444,7 +458,7 @@ public void channelCountShouldNotChangeWhenOutstandingRpcsAreWithinLimits() thro
return channel;
};

ChannelPool pool =
pool =
new ChannelPool(
ChannelPoolSettings.builder()
.setInitialChannelCount(2)
Expand Down Expand Up @@ -525,7 +539,7 @@ public void removedIdleChannelsAreShutdown() throws Exception {
return channel;
};

ChannelPool pool =
pool =
new ChannelPool(
ChannelPoolSettings.builder()
.setInitialChannelCount(2)
Expand Down Expand Up @@ -565,7 +579,7 @@ public void removedActiveChannelsAreShutdown() throws Exception {
return channel;
};

ChannelPool pool =
pool =
new ChannelPool(
ChannelPoolSettings.builder()
.setInitialChannelCount(2)
Expand Down Expand Up @@ -612,11 +626,11 @@ public void testReleasingClientCallCancelEarly() throws IOException {
Mockito.when(fakeChannel.newCall(Mockito.any(), Mockito.any())).thenReturn(mockClientCall);
ChannelPoolSettings channelPoolSettings = ChannelPoolSettings.staticallySized(1);
ChannelFactory factory = new FakeChannelFactory(ImmutableList.of(fakeChannel));
ChannelPool channelPool = ChannelPool.create(channelPoolSettings, factory);
pool = ChannelPool.create(channelPoolSettings, factory);
ClientContext context =
ClientContext.newBuilder()
.setTransportChannel(GrpcTransportChannel.create(channelPool))
.setDefaultCallContext(GrpcCallContext.of(channelPool, CallOptions.DEFAULT))
.setTransportChannel(GrpcTransportChannel.create(pool))
.setDefaultCallContext(GrpcCallContext.of(pool, CallOptions.DEFAULT))
.build();
ServerStreamingCallSettings settings =
ServerStreamingCallSettings.<Color, Money>newBuilder().build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@
*/
package com.google.api.gax.grpc;

import com.google.api.gax.core.BackgroundResource;
import com.google.api.gax.core.NoCredentialsProvider;
import com.google.api.gax.grpc.testing.FakeServiceGrpc;
import com.google.api.gax.retrying.RetrySettings;
import com.google.api.gax.retrying.StreamResumptionStrategy;
import com.google.api.gax.rpc.Callables;
import com.google.api.gax.rpc.ClientContext;
import com.google.api.gax.rpc.DeadlineExceededException;
import com.google.api.gax.rpc.FixedTransportChannelProvider;
Expand All @@ -49,6 +49,9 @@
import io.grpc.ServerBuilder;
import io.grpc.Status;
import io.grpc.stub.StreamObserver;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.junit.Test;
Expand All @@ -58,15 +61,13 @@

@RunWith(JUnit4.class)
public class GrpcDirectStreamControllerTest {
private static final int DEFAULT_AWAIT_TERMINATION_SEC = 10;

@Test(timeout = 180_000) // ms
public void testRetryNoRaceCondition() throws Exception {
Server server = ServerBuilder.forPort(1234).addService(new FakeService()).build();
server.start();

Server server = ServerBuilder.forPort(1234).addService(new FakeService()).build().start();
ManagedChannel channel =
ManagedChannelBuilder.forAddress("localhost", 1234).usePlaintext().build();

StreamResumptionStrategy<Color, Money> resumptionStrategy =
new StreamResumptionStrategy<Color, Money>() {
@Nonnull
Expand All @@ -92,58 +93,68 @@ public boolean canResume() {
return true;
}
};

// Set up retry settings. Set total timeout to 1 minute to limit the total runtime of this test.
// Set retry delay to 1 ms so the retries will be scheduled in a loop with no delays.
// Set max attempt to max so there could be as many retries as possible.
ServerStreamingCallSettings<Color, Money> callSettigs =
// Set up retry settings. Set total timeout to 1 minute to limit the total runtime of this
// test. Set retry delay to 1 ms so the retries will be scheduled in a loop with no delays.
ServerStreamingCallSettings<Color, Money> callSettings =
ServerStreamingCallSettings.<Color, Money>newBuilder()
.setResumptionStrategy(resumptionStrategy)
.setRetryableCodes(StatusCode.Code.DEADLINE_EXCEEDED)
.setRetrySettings(
RetrySettings.newBuilder()
.setTotalTimeout(Duration.ofMinutes(1))
.setMaxAttempts(Integer.MAX_VALUE)
.setInitialRpcTimeout(Duration.ofMillis(1))
.setMaxRpcTimeout(Duration.ofMillis(1))
.setInitialRetryDelay(Duration.ofMillis(1))
.setMaxRetryDelay(Duration.ofMillis(1))
.build())
.build();

StubSettings.Builder builder =
new StubSettings.Builder() {
@Override
public StubSettings build() {
return new StubSettings(this) {
@Override
public Builder toBuilder() {
throw new IllegalStateException();
}
};
}
};

builder
.setEndpoint("localhost:1234")
.setCredentialsProvider(NoCredentialsProvider.create())
.setTransportChannelProvider(
FixedTransportChannelProvider.create(GrpcTransportChannel.create(channel)));

ServerStreamingCallable<Color, Money> callable =
GrpcCallableFactory.createServerStreamingCallable(
GrpcCallSettings.create(FakeServiceGrpc.METHOD_SERVER_STREAMING_RECOGNIZE),
callSettigs,
ClientContext.create(builder.build()));

ServerStreamingCallable<Color, Money> retrying =
Callables.retrying(callable, callSettigs, ClientContext.create(builder.build()));

Color request = Color.newBuilder().getDefaultInstanceForType();

// Store a list of resources to manually close at the end of the test
List<BackgroundResource> backgroundResourceList = new ArrayList<>();
try {
for (Money money : retrying.call(request, GrpcCallContext.createDefault())) {}

GrpcTransportChannel transportChannel = GrpcTransportChannel.create(channel);
backgroundResourceList.add(transportChannel);

StubSettings.Builder builder =
new StubSettings.Builder() {
@Override
public StubSettings build() {
return new StubSettings(this) {
@Override
public Builder toBuilder() {
throw new IllegalStateException();
}
};
}
};

builder
.setEndpoint("localhost:1234")
.setCredentialsProvider(NoCredentialsProvider.create())
.setTransportChannelProvider(FixedTransportChannelProvider.create(transportChannel));

ClientContext clientContext = ClientContext.create(builder.build());
backgroundResourceList.addAll(clientContext.getBackgroundResources());
// GrpcCallableFactory's createServerStreamingCallable creates a retrying callable
ServerStreamingCallable<Color, Money> callable =
GrpcCallableFactory.createServerStreamingCallable(
GrpcCallSettings.create(FakeServiceGrpc.METHOD_SERVER_STREAMING_RECOGNIZE),
callSettings,
clientContext);

Color request = Color.newBuilder().getDefaultInstanceForType();
for (Money money : callable.call(request, clientContext.getDefaultCallContext())) {}
} catch (DeadlineExceededException e) {
// Ignore this error
} finally {
// Shutdown all the resources
server.shutdown();
server.awaitTermination(DEFAULT_AWAIT_TERMINATION_SEC, TimeUnit.SECONDS);
channel.shutdown();
channel.awaitTermination(DEFAULT_AWAIT_TERMINATION_SEC, TimeUnit.SECONDS);
for (BackgroundResource backgroundResource : backgroundResourceList) {
backgroundResource.shutdown();
backgroundResource.awaitTermination(DEFAULT_AWAIT_TERMINATION_SEC, TimeUnit.SECONDS);
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,17 +81,14 @@ public void serverStreamingRecognize(
// because the InProcessServer uses a direct executor and will buffer the results ignoring
// cancellation
Runnable runnable =
new Runnable() {
@Override
public void run() {
try {
Thread.sleep((long) color.getGreen());
} catch (InterruptedException e) {
Thread.interrupted();
return;
}
() -> {
try {
Thread.sleep((long) color.getGreen());
responseObserver.onNext(convert(color));
responseObserver.onCompleted();
} catch (Exception e) {
Thread.interrupted();
responseObserver.onError(e);
}
};

Expand All @@ -107,9 +104,10 @@ public StreamObserver<Color> clientStreamingRecognize(StreamObserver<Money> resp
}

private static Money convert(Color color) {
Money result =
Money.newBuilder().setCurrencyCode("USD").setUnits((long) (color.getRed() * 255)).build();
return result;
return Money.newBuilder()
.setCurrencyCode("USD")
.setUnits((long) (color.getRed() * 255))
.build();
}

private static class RequestStreamObserver implements StreamObserver<Color> {
Expand Down

0 comments on commit 8cbea70

Please sign in to comment.