Skip to content

Commit

Permalink
PR fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
katcharov committed Oct 1, 2024
1 parent 14b9fee commit 7fcbbee
Show file tree
Hide file tree
Showing 8 changed files with 104 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
import com.mongodb.annotations.Immutable;
import com.mongodb.lang.Nullable;

import java.nio.channels.AsynchronousChannelGroup;
import java.util.concurrent.ExecutorService;

import static com.mongodb.assertions.Assertions.notNull;
Expand Down Expand Up @@ -54,13 +53,17 @@ private Builder() {
}

/**
* Sets the executor service. This executor service will not be shut
* down by the driver code, and must be shut down by application code.
* The executor service, intended to be used exclusively by the mongo
* client. Closing the mongo client will result in orderly shutdown
* of the executor service.
*
* <p>When TLS is not enabled, see
* {@link java.nio.channels.AsynchronousChannelGroup#withThreadPool(ExecutorService)}
* for additional requirements for the executor service.
*
* @param executorService the executor service
* @return this
* @see #getExecutorService()
* @see AsynchronousChannelGroup#withThreadPool(ExecutorService)
*/
public Builder executorService(final ExecutorService executorService) {
this.executorService = notNull("executorService", executorService);
Expand Down Expand Up @@ -89,8 +92,8 @@ public ExecutorService getExecutorService() {

@Override
public String toString() {
return "AsyncTransportSettings{" +
"executorService=" + executorService +
'}';
return "AsyncTransportSettings{"
+ "executorService=" + executorService
+ '}';
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright 2008-present MongoDB, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.mongodb.internal;

import com.mongodb.internal.function.CheckedSupplier;

/**
* <p>This class is not part of the public API and may be removed or changed at any time</p>
*/
public class ValueOrExceptionContainer<T> {
private final T value;
private final Exception exception;

public ValueOrExceptionContainer(final CheckedSupplier<T, Exception> supplier) {
T value = null;
Exception exception = null;
try {
value = supplier.get();
} catch (Exception e) {
exception = e;
}
this.value = value;
this.exception = exception;
}

public T get() throws Exception {
if (isCompletedExceptionally()) {
throw exception;
}
return value;
}

public boolean isCompletedExceptionally() {
return exception != null;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import com.mongodb.ServerAddress;
import com.mongodb.connection.AsyncCompletionHandler;
import com.mongodb.connection.SocketSettings;
import com.mongodb.internal.ValueOrExceptionContainer;
import com.mongodb.lang.Nullable;
import com.mongodb.spi.dns.InetAddressResolver;

Expand All @@ -33,7 +34,6 @@
import java.nio.channels.CompletionHandler;
import java.util.LinkedList;
import java.util.Queue;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
Expand All @@ -49,7 +49,7 @@ public final class AsynchronousSocketChannelStream extends AsynchronousChannelSt
private final InetAddressResolver inetAddressResolver;
private final SocketSettings settings;
@Nullable
private final ExecutorService executorService;
private final ValueOrExceptionContainer<AsynchronousChannelGroup> group;

public AsynchronousSocketChannelStream(
final ServerAddress serverAddress, final InetAddressResolver inetAddressResolver,
Expand All @@ -60,12 +60,12 @@ public AsynchronousSocketChannelStream(
public AsynchronousSocketChannelStream(
final ServerAddress serverAddress, final InetAddressResolver inetAddressResolver,
final SocketSettings settings, final PowerOfTwoBufferPool bufferProvider,
@Nullable final ExecutorService executorService) {
@Nullable final ValueOrExceptionContainer<AsynchronousChannelGroup> group) {
super(serverAddress, settings, bufferProvider);
this.serverAddress = serverAddress;
this.inetAddressResolver = inetAddressResolver;
this.settings = settings;
this.executorService = executorService;
this.group = group;
}

@Override
Expand All @@ -91,9 +91,8 @@ private void initializeSocketChannel(final AsyncCompletionHandler<Void> handler,

try {
AsynchronousSocketChannel attemptConnectionChannel;
if (executorService != null) {
AsynchronousChannelGroup group = AsynchronousChannelGroup.withThreadPool(executorService);
attemptConnectionChannel = AsynchronousSocketChannel.open(group);
if (group != null) {
attemptConnectionChannel = AsynchronousSocketChannel.open(group.get());
} else {
attemptConnectionChannel = AsynchronousSocketChannel.open();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,11 @@
import com.mongodb.ServerAddress;
import com.mongodb.connection.SocketSettings;
import com.mongodb.connection.SslSettings;
import com.mongodb.internal.ValueOrExceptionContainer;
import com.mongodb.lang.Nullable;
import com.mongodb.spi.dns.InetAddressResolver;

import java.util.concurrent.ExecutorService;
import java.nio.channels.AsynchronousChannelGroup;

import static com.mongodb.assertions.Assertions.assertFalse;
import static com.mongodb.assertions.Assertions.notNull;
Expand All @@ -35,7 +36,7 @@ public class AsynchronousSocketChannelStreamFactory implements StreamFactory {
private final SocketSettings settings;
private final InetAddressResolver inetAddressResolver;
@Nullable
private final ExecutorService executorService;
private final ValueOrExceptionContainer<AsynchronousChannelGroup> group;

/**
* Create a new factory with the default {@code BufferProvider} and {@code AsynchronousChannelGroup}.
Expand All @@ -51,17 +52,17 @@ public AsynchronousSocketChannelStreamFactory(

AsynchronousSocketChannelStreamFactory(
final InetAddressResolver inetAddressResolver, final SocketSettings settings,
final SslSettings sslSettings, @Nullable final ExecutorService executorService) {
final SslSettings sslSettings, @Nullable final ValueOrExceptionContainer<AsynchronousChannelGroup> group) {
assertFalse(sslSettings.isEnabled());
this.inetAddressResolver = inetAddressResolver;
this.settings = notNull("settings", settings);
this.executorService = executorService;
this.group = group;
}

@Override
public Stream create(final ServerAddress serverAddress) {
return new AsynchronousSocketChannelStream(
serverAddress, inetAddressResolver, settings, bufferProvider, executorService);
serverAddress, inetAddressResolver, settings, bufferProvider, group);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,11 @@

import com.mongodb.connection.SocketSettings;
import com.mongodb.connection.SslSettings;
import com.mongodb.internal.ValueOrExceptionContainer;
import com.mongodb.lang.Nullable;
import com.mongodb.spi.dns.InetAddressResolver;

import java.util.concurrent.ExecutorService;
import java.nio.channels.AsynchronousChannelGroup;

/**
* A {@code StreamFactoryFactory} implementation for AsynchronousSocketChannel-based streams.
Expand All @@ -31,26 +32,33 @@
public final class AsynchronousSocketChannelStreamFactoryFactory implements StreamFactoryFactory {
private final InetAddressResolver inetAddressResolver;
@Nullable
private final ExecutorService executorService;
private final ValueOrExceptionContainer<AsynchronousChannelGroup> group;

public AsynchronousSocketChannelStreamFactoryFactory(final InetAddressResolver inetAddressResolver) {
this(inetAddressResolver, null);
}

AsynchronousSocketChannelStreamFactoryFactory(
final InetAddressResolver inetAddressResolver,
@Nullable final ExecutorService executorService) {
@Nullable final ValueOrExceptionContainer<AsynchronousChannelGroup> group) {
this.inetAddressResolver = inetAddressResolver;
this.executorService = executorService;
this.group = group;
}

@Override
public StreamFactory create(final SocketSettings socketSettings, final SslSettings sslSettings) {
return new AsynchronousSocketChannelStreamFactory(
inetAddressResolver, socketSettings, sslSettings, executorService);
inetAddressResolver, socketSettings, sslSettings, group);
}

@Override
public void close() {
if (group != null && !group.isCompletedExceptionally()) {
try {
group.get().shutdown();
} catch (Exception e) {
// will not occur, since it was not completed exceptionally
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
import com.mongodb.connection.NettyTransportSettings;
import com.mongodb.connection.SocketSettings;
import com.mongodb.connection.TransportSettings;
import com.mongodb.internal.ValueOrExceptionContainer;
import com.mongodb.internal.connection.netty.NettyStreamFactoryFactory;
import com.mongodb.spi.dns.InetAddressResolver;

import java.nio.channels.AsynchronousChannelGroup;
import java.util.concurrent.ExecutorService;

/**
Expand Down Expand Up @@ -57,7 +59,9 @@ public static StreamFactoryFactory getAsyncStreamFactoryFactory(final MongoClien
if (settings.getSslSettings().isEnabled()) {
return new TlsChannelStreamFactoryFactory(inetAddressResolver, executorService);
} else {
return new AsynchronousSocketChannelStreamFactoryFactory(inetAddressResolver, executorService);
ValueOrExceptionContainer<AsynchronousChannelGroup> group = new ValueOrExceptionContainer<>(
() -> AsynchronousChannelGroup.withThreadPool(executorService));
return new AsynchronousSocketChannelStreamFactoryFactory(inetAddressResolver, group);
}
} else if (transportSettings instanceof NettyTransportSettings) {
return getNettyStreamFactoryFactory(inetAddressResolver, (NettyTransportSettings) transportSettings);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ public class AsynchronousTlsChannelGroup {
private static final int queueLengthMultiplier = 32;

private static final AtomicInteger globalGroupCount = new AtomicInteger();
private final boolean executorIsExternal;

class RegisteredSocket {

Expand Down Expand Up @@ -210,11 +209,9 @@ public AsynchronousTlsChannelGroup(@Nullable final ExecutorService executorServi
}
timeoutExecutor.setRemoveOnCancelPolicy(true);
if (executorService != null) {
this.executorIsExternal = true;
this.executor = executorService;
} else {
int nThreads = Runtime.getRuntime().availableProcessors();
this.executorIsExternal = false;
this.executor = new ThreadPoolExecutor(
nThreads,
nThreads,
Expand Down Expand Up @@ -424,9 +421,7 @@ private void loop() {
} catch (Throwable e) {
LOGGER.error("error in selector loop", e);
} finally {
if (!executorIsExternal) {
executor.shutdown();
}
executor.shutdown();
// use shutdownNow to stop delayed tasks
timeoutExecutor.shutdownNow();
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,17 @@
import com.mongodb.connection.TransportSettings;
import com.mongodb.reactivestreams.client.syncadapter.SyncMongoClient;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.ValueSource;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import static com.mongodb.client.Fixture.getMongoClientSettingsBuilder;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

class AsyncTransportSettingsTest {
Expand All @@ -51,20 +53,26 @@ void testAsyncTransportSettings() {
verify(executorService, atLeastOnce()).execute(any());
}

@Test
void testExternalExecutorNotShutDown() {
@ParameterizedTest
@ValueSource(booleans = {true, false})
void testExternalExecutorWasShutDown(final boolean tlsEnabled) {
ExecutorService executorService = spy(Executors.newFixedThreadPool(5));
AsyncTransportSettings asyncTransportSettings = TransportSettings.asyncBuilder()
.executorService(executorService)
.build();
MongoClientSettings mongoClientSettings = getMongoClientSettingsBuilder()
.applyToSslSettings(builder -> builder.enabled(true))
.applyToSslSettings(builder -> builder.enabled(tlsEnabled))
.transportSettings(asyncTransportSettings)
.build();

try (MongoClient ignored = new SyncMongoClient(MongoClients.create(mongoClientSettings))) {
// ignored
}
verify(executorService, never()).shutdown();
try {
Thread.sleep(100);
} catch (InterruptedException e) {
throw new RuntimeException(e);
}
verify(executorService, times(1)).shutdown();
}
}

0 comments on commit 7fcbbee

Please sign in to comment.