Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor to obtain timeout from OperationContext for server selection #1209

Merged
merged 4 commits into from
Oct 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions config/spotbugs/exclude.xml
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,11 @@
<Method name="awaitOn"/>
<Bug pattern="RV_RETURN_VALUE_IGNORED_BAD_PRACTICE"/>
</Match>
<Match>
<Class name="com.mongodb.internal.time.Timeout"/>
<Method name="awaitOn"/>
<Bug pattern="RV_RETURN_VALUE_IGNORED"/>
</Match>
<Match>
<Class name="com.mongodb.internal.time.Timeout"/>
<Method name="awaitOn"/>
Expand Down
6 changes: 6 additions & 0 deletions driver-core/src/main/com/mongodb/internal/TimeoutContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package com.mongodb.internal;

import com.mongodb.internal.time.StartTime;
import com.mongodb.internal.time.Timeout;
import com.mongodb.lang.Nullable;

Expand Down Expand Up @@ -152,4 +153,9 @@ private static Timeout calculateTimeout(@Nullable final Long timeoutMS) {
}
return null;
}

public Timeout startServerSelectionTimeout() {
long ms = getTimeoutSettings().getServerSelectionTimeoutMS();
return StartTime.now().timeoutAfterOrInfiniteIfNegative(ms, MILLISECONDS);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import com.mongodb.event.ServerDescriptionChangedEvent;
import com.mongodb.internal.diagnostics.logging.Logger;
import com.mongodb.internal.diagnostics.logging.Loggers;
import com.mongodb.internal.time.Timeout;
import com.mongodb.lang.Nullable;
import org.bson.types.ObjectId;

Expand Down Expand Up @@ -122,7 +123,7 @@ public void close() {
}

@Override
public ClusterableServer getServer(final ServerAddress serverAddress) {
public ClusterableServer getServer(final ServerAddress serverAddress, final Timeout serverSelectionTimeout) {
isTrue("is open", !isClosed());

ServerTuple serverTuple = addressToServerTupleMap.get(serverAddress);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
import com.mongodb.internal.diagnostics.logging.Logger;
import com.mongodb.internal.diagnostics.logging.Loggers;
import com.mongodb.internal.selector.LatencyMinimizingServerSelector;
import com.mongodb.internal.time.StartTime;
import com.mongodb.internal.time.Timeout;
import com.mongodb.lang.Nullable;
import com.mongodb.selector.CompositeServerSelector;
Expand Down Expand Up @@ -106,20 +105,19 @@ public ServerTuple selectServer(final ServerSelector serverSelector, final Opera

ServerSelector compositeServerSelector = getCompositeServerSelector(serverSelector);
boolean selectionFailureLogged = false;
StartTime startTime = StartTime.now();
Timeout timeout = startServerSelectionTimeout(startTime);
Timeout timeout = operationContext.getTimeoutContext().startServerSelectionTimeout();

while (true) {
CountDownLatch currentPhaseLatch = phase.get();
ClusterDescription currentDescription = description;
ServerTuple serverTuple = selectServer(compositeServerSelector, currentDescription);
ServerTuple serverTuple = selectServer(compositeServerSelector, currentDescription, timeout);

throwIfIncompatible(currentDescription);
if (serverTuple != null) {
return serverTuple;
}
if (timeout.hasExpired()) {
throw createTimeoutException(serverSelector, currentDescription, startTime);
throw createTimeoutException(serverSelector, currentDescription);
}
if (!selectionFailureLogged) {
logServerSelectionFailure(serverSelector, currentDescription, timeout);
Expand All @@ -140,10 +138,9 @@ public void selectServerAsync(final ServerSelector serverSelector, final Operati
if (LOGGER.isTraceEnabled()) {
LOGGER.trace(format("Asynchronously selecting server with selector %s", serverSelector));
}
StartTime startTime = StartTime.now();
Timeout timeout = startServerSelectionTimeout(startTime);
Timeout timeout = operationContext.getTimeoutContext().startServerSelectionTimeout();
ServerSelectionRequest request = new ServerSelectionRequest(
serverSelector, getCompositeServerSelector(serverSelector), timeout, startTime, callback);
serverSelector, getCompositeServerSelector(serverSelector), timeout, callback);

CountDownLatch currentPhase = phase.get();
ClusterDescription currentDescription = description;
Expand Down Expand Up @@ -218,11 +215,6 @@ private void updatePhase() {
withLock(() -> phase.getAndSet(new CountDownLatch(1)).countDown());
}

private Timeout startServerSelectionTimeout(final StartTime startTime) {
long ms = settings.getServerSelectionTimeout(MILLISECONDS);
return startTime.timeoutAfterOrInfiniteIfNegative(ms, MILLISECONDS);
}

private Timeout startMinWaitHeartbeatTimeout() {
long minHeartbeatFrequency = serverFactory.getSettings().getMinHeartbeatFrequency(NANOSECONDS);
minHeartbeatFrequency = Math.max(0, minHeartbeatFrequency);
Expand All @@ -244,7 +236,7 @@ private boolean handleServerSelectionRequest(
return true;
}

ServerTuple serverTuple = selectServer(request.compositeSelector, description);
ServerTuple serverTuple = selectServer(request.compositeSelector, description, request.getTimeout());
if (serverTuple != null) {
if (LOGGER.isTraceEnabled()) {
LOGGER.trace(format("Asynchronously selected server %s",
Expand All @@ -262,8 +254,7 @@ private boolean handleServerSelectionRequest(
if (LOGGER.isTraceEnabled()) {
LOGGER.trace("Asynchronously failed server selection after timeout");
}
request.onResult(null, createTimeoutException(
request.originalSelector, description, request.getStartTime()));
request.onResult(null, createTimeoutException(request.originalSelector, description));
return true;
}

Expand All @@ -289,8 +280,11 @@ private void logServerSelectionFailure(final ServerSelector serverSelector,

@Nullable
private ServerTuple selectServer(final ServerSelector serverSelector,
final ClusterDescription clusterDescription) {
return selectServer(serverSelector, clusterDescription, this::getServer);
final ClusterDescription clusterDescription, final Timeout serverSelectionTimeout) {
return selectServer(
serverSelector,
clusterDescription,
serverAddress -> getServer(serverAddress, serverSelectionTimeout));
}

@Nullable
Expand Down Expand Up @@ -369,10 +363,9 @@ private MongoIncompatibleDriverException createIncompatibleException(final Clust
}

private MongoTimeoutException createTimeoutException(final ServerSelector serverSelector,
final ClusterDescription curDescription, final StartTime startTime) {
final ClusterDescription curDescription) {
return new MongoTimeoutException(format(
"Timed out after %d ms while waiting for a server that matches %s. Client view of cluster state is %s",
startTime.elapsed().toMillis(),
"Timed out while waiting for a server that matches %s. Client view of cluster state is %s",
serverSelector,
curDescription.getShortDescription()));
}
Expand All @@ -382,15 +375,13 @@ private static final class ServerSelectionRequest {
private final ServerSelector compositeSelector;
private final SingleResultCallback<ServerTuple> callback;
private final Timeout timeout;
private final StartTime startTime;
private CountDownLatch phase;

ServerSelectionRequest(final ServerSelector serverSelector, final ServerSelector compositeSelector,
final Timeout timeout, final StartTime startTime, final SingleResultCallback<ServerTuple> callback) {
final Timeout timeout, final SingleResultCallback<ServerTuple> callback) {
this.originalSelector = serverSelector;
this.compositeSelector = compositeSelector;
this.timeout = timeout;
this.startTime = startTime;
this.callback = callback;
}

Expand All @@ -405,10 +396,6 @@ void onResult(@Nullable final ServerTuple serverTuple, @Nullable final Throwable
Timeout getTimeout() {
return timeout;
}

StartTime getStartTime() {
return startTime;
}
}

private void notifyWaitQueueHandler(final ServerSelectionRequest request) {
Expand Down Expand Up @@ -438,6 +425,10 @@ private void stopWaitQueueHandler() {
}

private final class WaitQueueHandler implements Runnable {

WaitQueueHandler() {
}

public void run() {
while (!isClosed) {
CountDownLatch currentPhase = phase.get();
Expand All @@ -446,12 +437,12 @@ public void run() {
Timeout timeout = Timeout.infinite();

for (Iterator<ServerSelectionRequest> iter = waitQueue.iterator(); iter.hasNext();) {
ServerSelectionRequest nextRequest = iter.next();
if (handleServerSelectionRequest(nextRequest, currentPhase, curDescription)) {
ServerSelectionRequest currentRequest = iter.next();
if (handleServerSelectionRequest(currentRequest, currentPhase, curDescription)) {
iter.remove();
} else {
timeout = timeout
.orEarlier(nextRequest.getTimeout())
.orEarlier(currentRequest.getTimeout())
.orEarlier(startMinWaitHeartbeatTimeout());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import com.mongodb.internal.async.SingleResultCallback;
import com.mongodb.connection.ClusterDescription;
import com.mongodb.connection.ClusterSettings;
import com.mongodb.internal.time.Timeout;
import com.mongodb.lang.Nullable;
import com.mongodb.selector.ServerSelector;

Expand All @@ -45,7 +46,7 @@ public interface Cluster extends Closeable {

@Nullable
@VisibleForTesting(otherwise = PRIVATE)
ClusterableServer getServer(ServerAddress serverAddress);
ClusterableServer getServer(ServerAddress serverAddress, Timeout serverSelectionTimeout);

/**
* Get the current description of this cluster.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@
import static java.lang.String.format;
import static java.util.Collections.emptyList;
import static java.util.Collections.singletonList;
import static java.util.concurrent.TimeUnit.MILLISECONDS;

@ThreadSafe
final class LoadBalancedCluster implements Cluster {
Expand Down Expand Up @@ -181,9 +180,9 @@ public ClusterId getClusterId() {
}

@Override
public ClusterableServer getServer(final ServerAddress serverAddress) {
public ClusterableServer getServer(final ServerAddress serverAddress, final Timeout serverSelectionTimeout) {
isTrue("open", !isClosed());
waitForSrv();
waitForSrv(serverSelectionTimeout);
return assertNotNull(server);
}

Expand All @@ -202,29 +201,27 @@ public ClusterClock getClock() {
@Override
public ServerTuple selectServer(final ServerSelector serverSelector, final OperationContext operationContext) {
isTrue("open", !isClosed());
waitForSrv();
Timeout serverSelectionTimeout = operationContext.getTimeoutContext().startServerSelectionTimeout();
waitForSrv(serverSelectionTimeout);
if (srvRecordResolvedToMultipleHosts) {
throw createResolvedToMultipleHostsException();
}
return new ServerTuple(assertNotNull(server), description.getServerDescriptions().get(0));
}


private void waitForSrv() {
private void waitForSrv(final Timeout serverSelectionTimeout) {
if (initializationCompleted) {
return;
}
Locks.withLock(lock, () -> {
StartTime startTime = StartTime.now();
Timeout timeout = startServerSelectionTimeout(startTime);
while (!initializationCompleted) {
if (isClosed()) {
throw createShutdownException();
}
if (timeout.hasExpired()) {
throw createTimeoutException(startTime);
if (serverSelectionTimeout.hasExpired()) {
throw createTimeoutException();
}
timeout.awaitOn(condition, () -> format("resolving SRV records for %s", settings.getSrvHost()));
serverSelectionTimeout.awaitOn(condition, () -> format("resolving SRV records for %s", settings.getSrvHost()));
}
});
}
Expand All @@ -236,7 +233,7 @@ public void selectServerAsync(final ServerSelector serverSelector, final Operati
callback.onResult(null, createShutdownException());
return;
}
Timeout timeout = startServerSelectionTimeout(StartTime.now());
Timeout timeout = operationContext.getTimeoutContext().startServerSelectionTimeout();
ServerSelectionRequest serverSelectionRequest = new ServerSelectionRequest(timeout, callback);
if (initializationCompleted) {
handleServerSelectionRequest(serverSelectionRequest);
Expand Down Expand Up @@ -296,23 +293,18 @@ private MongoClientException createResolvedToMultipleHostsException() {
+ "to multiple hosts");
}

private MongoTimeoutException createTimeoutException(final StartTime startTime) {
private MongoTimeoutException createTimeoutException() {
MongoException localSrvResolutionException = srvResolutionException;
if (localSrvResolutionException == null) {
return new MongoTimeoutException(format("Timed out after %d ms while waiting to resolve SRV records for %s.",
startTime.elapsed().toMillis(), settings.getSrvHost()));
return new MongoTimeoutException(format("Timed out while waiting to resolve SRV records for %s.",
settings.getSrvHost()));
} else {
return new MongoTimeoutException(format("Timed out after %d ms while waiting to resolve SRV records for %s. "
return new MongoTimeoutException(format("Timed out while waiting to resolve SRV records for %s. "
+ "Resolution exception was '%s'",
startTime.elapsed().toMillis(), settings.getSrvHost(), localSrvResolutionException));
settings.getSrvHost(), localSrvResolutionException));
}
}

private Timeout startServerSelectionTimeout(final StartTime startTime) {
long ms = settings.getServerSelectionTimeout(MILLISECONDS);
return startTime.timeoutAfterOrInfiniteIfNegative(ms, MILLISECONDS);
}

private void notifyWaitQueueHandler(final ServerSelectionRequest request) {
Locks.withLock(lock, () -> {
if (isClosed()) {
Expand All @@ -338,7 +330,6 @@ private void notifyWaitQueueHandler(final ServerSelectionRequest request) {

private final class WaitQueueHandler implements Runnable {
public void run() {
StartTime startTime = StartTime.now();
List<ServerSelectionRequest> timeoutList = new ArrayList<>();
while (!(isClosed() || initializationCompleted)) {
lockInterruptibly(lock);
Expand Down Expand Up @@ -369,7 +360,7 @@ public void run() {
} finally {
lock.unlock();
}
timeoutList.forEach(request -> request.onError(createTimeoutException(startTime)));
timeoutList.forEach(request -> request.onError(createTimeoutException()));
timeoutList.clear();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.mongodb.internal.diagnostics.logging.Logger;
import com.mongodb.internal.diagnostics.logging.Loggers;
import com.mongodb.event.ServerDescriptionChangedEvent;
import com.mongodb.internal.time.Timeout;

import java.util.concurrent.atomic.AtomicReference;

Expand Down Expand Up @@ -69,7 +70,7 @@ protected void connect() {
}

@Override
public ClusterableServer getServer(final ServerAddress serverAddress) {
public ClusterableServer getServer(final ServerAddress serverAddress, final Timeout serverSelectionTimeout) {
isTrue("open", !isClosed());
return assertNotNull(server.get());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,10 @@ public void shouldGetServerWithOkDescription() {
@Test
public void shouldSuccessfullyQueryASecondaryWithPrimaryReadPreference() {
// given
OperationContext operationContext = OPERATION_CONTEXT;
ServerAddress secondary = getSecondary();
setUpCluster(secondary);
String collectionName = getClass().getName();
OperationContext operationContext = OPERATION_CONTEXT;
Connection connection = cluster.selectServer(new ServerAddressSelector(secondary), operationContext).getServer()
.getConnection(operationContext);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import com.mongodb.connection.ServerType;
import com.mongodb.event.ClusterListener;
import com.mongodb.internal.connection.SdamServerDescriptionManager.SdamIssue;
import com.mongodb.internal.time.Timeout;
import org.bson.BsonArray;
import org.bson.BsonDocument;
import org.bson.BsonInt32;
Expand All @@ -42,6 +43,7 @@
import java.util.List;
import java.util.concurrent.TimeUnit;

import static com.mongodb.ClusterFixture.OPERATION_CONTEXT;
import static com.mongodb.connection.ServerConnectionState.CONNECTING;
import static com.mongodb.internal.connection.DescriptionHelper.createServerDescription;
import static com.mongodb.internal.connection.ProtocolHelper.getCommandFailureException;
Expand Down Expand Up @@ -79,14 +81,15 @@ protected void applyResponse(final BsonArray response) {
}

protected void applyApplicationError(final BsonDocument applicationError) {
Timeout serverSelectionTimeout = OPERATION_CONTEXT.getTimeoutContext().startServerSelectionTimeout();
ServerAddress serverAddress = new ServerAddress(applicationError.getString("address").getValue());
int errorGeneration = applicationError.getNumber("generation",
new BsonInt32(((DefaultServer) getCluster().getServer(serverAddress)).getConnectionPool().getGeneration())).intValue();
new BsonInt32(((DefaultServer) getCluster().getServer(serverAddress, serverSelectionTimeout)).getConnectionPool().getGeneration())).intValue();
int maxWireVersion = applicationError.getNumber("maxWireVersion").intValue();
String when = applicationError.getString("when").getValue();
String type = applicationError.getString("type").getValue();

DefaultServer server = (DefaultServer) cluster.getServer(serverAddress);
DefaultServer server = (DefaultServer) cluster.getServer(serverAddress, serverSelectionTimeout);
RuntimeException exception;

switch (type) {
Expand Down
Loading