Fix DataNodeRequestSender (elastic#121999)
There are two issues in the current implementation:

1. We should use the list of shardIds from the request, rather than all
targets, when removing failures for shards that have been successfully
executed (see the sketch below).

2. We should remove shardIds from the pending list once a failure has been
reported, and abort execution at that point, since any remaining results will
be discarded.

Closes elastic#121966
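
Below is a rough sketch of the intent behind fix (1), using simplified stand-in types rather than the real Elasticsearch classes (FailureTracker, String shard ids, and plain Sets are hypothetical here): on a successful response, failures may only be cleared for the shards that the request actually covered, never for all target shards.

import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

class FailureTracker {
    // shardId -> last failure recorded for that shard
    private final Map<String, Exception> shardFailures = new ConcurrentHashMap<>();

    void onShardFailure(String shardId, Exception failure) {
        shardFailures.put(shardId, failure);
    }

    // On a successful node response, clear failures only for the shards that
    // were part of this request and were not reported as failed again;
    // failures recorded for all other target shards are left untouched.
    void onNodeResponse(Set<String> requestShardIds, Set<String> shardLevelFailures) {
        for (String shardId : requestShardIds) {
            if (shardLevelFailures.contains(shardId) == false) {
                shardFailures.remove(shardId);
            }
        }
    }
}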
dnhatn committed Feb 11, 2025
1 parent d3ff1cf commit ad780e2
Showing 1 changed file with 26 additions and 5 deletions.
@@ -34,7 +34,9 @@
 import org.elasticsearch.xpack.esql.action.EsqlSearchShardsAction;
 
 import java.util.ArrayList;
+import java.util.Collections;
 import java.util.HashMap;
+import java.util.IdentityHashMap;
 import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
@@ -58,6 +60,7 @@ abstract class DataNodeRequestSender {
     private final Map<DiscoveryNode, Semaphore> nodePermits = new HashMap<>();
     private final Map<ShardId, ShardFailure> shardFailures = ConcurrentCollections.newConcurrentMap();
     private final AtomicBoolean changed = new AtomicBoolean();
+    private boolean reportedFailure = false; // guarded by sendingLock
 
     DataNodeRequestSender(TransportService transportService, Executor esqlExecutor, CancellableTask rootTask) {
         this.transportService = transportService;
@@ -117,11 +120,14 @@ private void trySendingRequestsForPendingShards(TargetShards targetShards, Compu
                 );
             }
         }
-        if (shardFailures.values().stream().anyMatch(shardFailure -> shardFailure.fatal)) {
-            for (var e : shardFailures.values()) {
-                computeListener.acquireAvoid().onFailure(e.failure);
-            }
+        if (reportedFailure || shardFailures.values().stream().anyMatch(shardFailure -> shardFailure.fatal)) {
+            reportedFailure = true;
+            reportFailures(computeListener);
         } else {
+            pendingShardIds.removeIf(shr -> {
+                var failure = shardFailures.get(shr);
+                return failure != null && failure.fatal;
+            });
             var nodeRequests = selectNodeRequests(targetShards);
             for (NodeRequest request : nodeRequests) {
                 sendOneNodeRequest(targetShards, computeListener, request);
@@ -136,6 +142,20 @@
         }
     }
 
+    private void reportFailures(ComputeListener computeListener) {
+        assert sendingLock.isHeldByCurrentThread();
+        assert reportedFailure;
+        Iterator<ShardFailure> it = shardFailures.values().iterator();
+        Set<Exception> seen = Collections.newSetFromMap(new IdentityHashMap<>());
+        while (it.hasNext()) {
+            ShardFailure failure = it.next();
+            if (seen.add(failure.failure)) {
+                computeListener.acquireAvoid().onFailure(failure.failure);
+            }
+            it.remove();
+        }
+    }
+
     private void sendOneNodeRequest(TargetShards targetShards, ComputeListener computeListener, NodeRequest request) {
         final ActionListener<List<DriverProfile>> listener = computeListener.acquireCompute();
         sendRequest(request.node, request.shardIds, request.aliasFilters, new NodeListener() {
@@ -148,7 +168,7 @@ void onAfter(List<DriverProfile> profiles) {
             @Override
             public void onResponse(DataNodeComputeResponse response) {
                 // remove failures of successful shards
-                for (ShardId shardId : targetShards.shardIds()) {
+                for (ShardId shardId : request.shardIds()) {
                     if (response.shardLevelFailures().containsKey(shardId) == false) {
                         shardFailures.remove(shardId);
                     }
Expand Down Expand Up @@ -250,6 +270,7 @@ private List<NodeRequest> selectNodeRequests(TargetShards targetShards) {
final Iterator<ShardId> shardsIt = pendingShardIds.iterator();
while (shardsIt.hasNext()) {
ShardId shardId = shardsIt.next();
assert shardFailures.get(shardId) == null || shardFailures.get(shardId).fatal == false;
TargetShard shard = targetShards.getShard(shardId);
Iterator<DiscoveryNode> nodesIt = shard.remainingNodes.iterator();
DiscoveryNode selectedNode = null;

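A note on the reportFailures method added above: it deduplicates with an identity-based set, presumably because the same Exception instance can be recorded under several shard ids (for instance, one node-level failure attributed to every shard on that node), and the compute listener should not be notified with the same exception object twice. A small standalone illustration of that pattern, with a made-up demo class and messages:

import java.util.Collections;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Set;

public class IdentityDedupDemo {
    public static void main(String[] args) {
        Exception nodeFailure = new RuntimeException("node disconnected");
        // The same exception instance recorded for three shards on one node,
        // plus a separate failure for a fourth shard.
        List<Exception> perShardFailures =
            List.of(nodeFailure, nodeFailure, nodeFailure, new RuntimeException("shard corrupted"));

        // Identity-based set: deduplicates by object reference, not equals().
        Set<Exception> seen = Collections.newSetFromMap(new IdentityHashMap<>());
        for (Exception failure : perShardFailures) {
            if (seen.add(failure)) {
                System.out.println("reporting: " + failure.getMessage());
            }
        }
        // Prints "node disconnected" once and "shard corrupted" once.
    }
}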