Retry broadcast OOM with BHJ disabled within the same spark session
Presto on Spark uses temp storage for storing and distributing
broadcast tables. The Spark driver performs the necessary
threshold checks on the broadcast table, and if its size exceeds
the threshold, the query fails with a broadcast OOM. The only way
to fix this failure is to disable broadcast join for the query.

Since we can confidently detect a broadcast OOM on the driver, we
can simply disable broadcast join, replan, and resubmit the query
for execution. This happens within the same Spark session, so no
user intervention is needed to fix such failures.
pgupta2 committed Mar 28, 2022
1 parent 603d617 commit af8feb8
Showing 16 changed files with 635 additions and 37 deletions.
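Taken together, the changes below make the driver classify a broadcast OOM as retryable and surface that classification to the submitter. The overall flow is roughly the sketch below, modeled on the PrestoSparkQueryRunner change in this commit; submitQuery is a hypothetical helper standing in for building a PrestoSparkSession (carrying the optional RetryExecutionStrategy), creating an IPrestoSparkQueryExecution, and running it.

// Sketch of the retry loop around query submission (mirrors PrestoSparkQueryRunner.execute below).
// PrestoSparkFailure and RetryExecutionStrategy are the classloader_interface types touched in
// this commit; submitQuery is a hypothetical helper, not an API added here.
private List<List<Object>> executeWithBroadcastOomRetry(String sql)
{
    try {
        // First attempt: no retry strategy, broadcast joins allowed.
        return submitQuery(sql, Optional.empty());
    }
    catch (PrestoSparkFailure failure) {
        // The failure carries a strategy only when the driver classified the error as retryable,
        // e.g. EXCEEDED_LOCAL_BROADCAST_JOIN_MEMORY_LIMIT -> DISABLE_BROADCAST_JOIN.
        if (failure.getRetryExecutionStrategy().isPresent()) {
            // Second attempt in the same Spark session, with broadcast join disabled via
            // join_distribution_type=PARTITIONED (see PrestoSparkSessionContext).
            return submitQuery(sql, failure.getRetryExecutionStrategy());
        }
        throw failure;
    }
}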
@@ -20,6 +20,7 @@
import java.util.Optional;

import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_GLOBAL_MEMORY_LIMIT;
import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_LOCAL_BROADCAST_JOIN_MEMORY_LIMIT;
import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_LOCAL_MEMORY_LIMIT;
import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_REVOCABLE_MEMORY_LIMIT;
import static com.facebook.presto.util.HeapDumper.dumpHeap;
@@ -51,7 +52,7 @@ public static ExceededMemoryLimitException exceededLocalUserMemoryLimit(

public static ExceededMemoryLimitException exceededLocalBroadcastMemoryLimit(DataSize maxMemory, String additionalFailureInfo)
{
return new ExceededMemoryLimitException(EXCEEDED_LOCAL_MEMORY_LIMIT,
return new ExceededMemoryLimitException(EXCEEDED_LOCAL_BROADCAST_JOIN_MEMORY_LIMIT,
format("Query exceeded per-node broadcast memory limit of %s [%s]", maxMemory, additionalFailureInfo));
}

@@ -41,6 +41,7 @@ public class PrestoSparkConfig
private int splitAssignmentBatchSize = 1_000_000;
private double memoryRevokingThreshold;
private double memoryRevokingTarget;
private boolean retryOnOutOfMemoryBroadcastJoinEnabled;

public boolean isSparkPartitionCountAutoTuneEnabled()
{
@@ -225,4 +226,17 @@ public PrestoSparkConfig setMemoryRevokingTarget(double memoryRevokingTarget)
this.memoryRevokingTarget = memoryRevokingTarget;
return this;
}

public boolean isRetryOnOutOfMemoryBroadcastJoinEnabled()
{
return retryOnOutOfMemoryBroadcastJoinEnabled;
}

@Config("spark.retry-on-out-of-memory-broadcast-join-enabled")
@ConfigDescription("Disable broadcast join on broadcast OOM and re-submit the query again within the same spark session")
public PrestoSparkConfig setRetryOnOutOfMemoryBroadcastJoinEnabled(boolean retryOnOutOfMemoryBroadcastJoinEnabled)
{
this.retryOnOutOfMemoryBroadcastJoinEnabled = retryOnOutOfMemoryBroadcastJoinEnabled;
return this;
}
}
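For reference, the new flag defaults to false (the boolean field above is left uninitialized), so enabling it is a one-line config change such as the following; the property name comes from the @Config annotation above, while the exact config file it belongs in depends on the deployment.

spark.retry-on-out-of-memory-broadcast-join-enabled=true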
@@ -173,6 +173,7 @@
import static com.facebook.presto.spark.SparkErrorCode.UNSUPPORTED_STORAGE_TYPE;
import static com.facebook.presto.spark.classloader_interface.ScalaUtils.collectScalaIterator;
import static com.facebook.presto.spark.classloader_interface.ScalaUtils.emptyScalaIterator;
import static com.facebook.presto.spark.util.PrestoSparkFailureUtils.toPrestoSparkFailure;
import static com.facebook.presto.spark.util.PrestoSparkUtils.classTag;
import static com.facebook.presto.spark.util.PrestoSparkUtils.computeNextTimeout;
import static com.facebook.presto.spark.util.PrestoSparkUtils.createPagesSerde;
@@ -510,7 +511,7 @@ public IPrestoSparkQueryExecution create(
log.error(eventFailure, "Error publishing query immediate failure event");
}

throw failureInfo.get().toFailure();
throw toPrestoSparkFailure(session, failureInfo.get());
}
}

@@ -1005,7 +1006,7 @@ else if (executionException instanceof TimeoutException) {
log.error(eventFailure, "Error publishing query completed event");
}

throw failureInfo.get().toFailure();
throw toPrestoSparkFailure(session, failureInfo.get());
}

processShuffleStats();
@@ -13,8 +13,10 @@
*/
package com.facebook.presto.spark;

import com.facebook.airlift.log.Logger;
import com.facebook.presto.server.SessionContext;
import com.facebook.presto.spark.classloader_interface.PrestoSparkSession;
import com.facebook.presto.spark.classloader_interface.RetryExecutionStrategy;
import com.facebook.presto.spi.function.SqlFunctionId;
import com.facebook.presto.spi.function.SqlInvokedFunction;
import com.facebook.presto.spi.security.Identity;
@@ -27,15 +29,21 @@

import javax.annotation.Nullable;

import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import static com.facebook.presto.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE;
import static com.facebook.presto.spark.classloader_interface.RetryExecutionStrategy.DISABLE_BROADCAST_JOIN;
import static com.facebook.presto.sql.analyzer.FeaturesConfig.JoinDistributionType;
import static java.util.Objects.requireNonNull;

public class PrestoSparkSessionContext
implements SessionContext
{
private static final Logger log = Logger.get(PrestoSparkSessionContext.class);

private final Identity identity;
private final String catalog;
private final String schema;
@@ -78,11 +86,25 @@ public static PrestoSparkSessionContext createFromSessionInfo(
prestoSparkSession.getClientTags(),
prestoSparkSession.getTimeZoneId().orElse(null),
prestoSparkSession.getLanguage().orElse(null),
prestoSparkSession.getSystemProperties(),
getFinalSystemProperties(prestoSparkSession.getSystemProperties(), prestoSparkSession.getRetryExecutionStrategy()),
prestoSparkSession.getCatalogSessionProperties(),
prestoSparkSession.getTraceToken());
}

private static Map<String, String> getFinalSystemProperties(Map<String, String> systemProperties, Optional<RetryExecutionStrategy> retryExecutionStrategy)
{
if (!retryExecutionStrategy.isPresent()) {
return systemProperties;
}

log.info("Applying retryExecutionStrategy: " + retryExecutionStrategy.get().name());
Map<String, String> retrySystemProperties = new HashMap<>(systemProperties);
if (retryExecutionStrategy.get() == DISABLE_BROADCAST_JOIN) {
retrySystemProperties.put(JOIN_DISTRIBUTION_TYPE, JoinDistributionType.PARTITIONED.name());
}
return retrySystemProperties;
}

public PrestoSparkSessionContext(
Identity identity,
String catalog,
@@ -41,6 +41,7 @@ public class PrestoSparkSessionProperties
public static final String SPARK_SPLIT_ASSIGNMENT_BATCH_SIZE = "spark_split_assignment_batch_size";
public static final String SPARK_MEMORY_REVOKING_THRESHOLD = "spark_memory_revoking_threshold";
public static final String SPARK_MEMORY_REVOKING_TARGET = "spark_memory_revoking_target";
public static final String SPARK_RETRY_ON_OUT_OF_MEMORY_BROADCAST_JOIN_ENABLED = "spark_retry_on_out_of_memory_broadcast_join_enabled";

private final List<PropertyMetadata<?>> sessionProperties;

@@ -107,6 +108,11 @@ public PrestoSparkSessionProperties(PrestoSparkConfig prestoSparkConfig)
SPARK_MEMORY_REVOKING_TARGET,
"When revoking memory, try to revoke so much that memory pool is filled below target at the end",
prestoSparkConfig.getMemoryRevokingTarget(),
false),
booleanProperty(
SPARK_RETRY_ON_OUT_OF_MEMORY_BROADCAST_JOIN_ENABLED,
"Disable broadcast join on broadcast OOM and re-submit the query again within the same spark session",
prestoSparkConfig.isRetryOnOutOfMemoryBroadcastJoinEnabled(),
false));
}

@@ -174,4 +180,9 @@ public static double getMemoryRevokingTarget(Session session)
{
return session.getSystemProperty(SPARK_MEMORY_REVOKING_TARGET, Double.class);
}

public static boolean isRetryOnOutOfMemoryBroadcastJoinEnabled(Session session)
{
return session.getSystemProperty(SPARK_RETRY_ON_OUT_OF_MEMORY_BROADCAST_JOIN_ENABLED, Boolean.class);
}
}
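The same behavior can also be enabled per query through the session property registered above. A minimal sketch, assuming the flag is supplied via the system-properties map carried by PrestoSparkSession and later read back through isRetryOnOutOfMemoryBroadcastJoinEnabled(session):

// Hypothetical per-query override; the literal key below is the value of
// SPARK_RETRY_ON_OUT_OF_MEMORY_BROADCAST_JOIN_ENABLED defined in this file.
Map<String, String> systemProperties = ImmutableMap.of(
        "spark_retry_on_out_of_memory_broadcast_join_enabled", "true");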
@@ -32,6 +32,7 @@
import java.util.Optional;

import static com.facebook.airlift.concurrent.MoreFutures.getFutureValue;
import static com.facebook.presto.spark.util.PrestoSparkFailureUtils.toPrestoSparkFailure;
import static com.facebook.presto.util.Failures.toFailure;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;
@@ -93,7 +94,7 @@ public List<List<Object>> execute()
Optional<ExecutionFailureInfo> failureInfo = Optional.of(toFailure(executionException));
queryStateTimer.endQuery();

throw failureInfo.get().toFailure();
throw toPrestoSparkFailure(session, failureInfo.get());
}
return Collections.emptyList();
}
@@ -0,0 +1,92 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.spark.util;

import com.facebook.presto.Session;
import com.facebook.presto.execution.ExecutionFailureInfo;
import com.facebook.presto.spark.classloader_interface.PrestoSparkFailure;
import com.facebook.presto.spark.classloader_interface.RetryExecutionStrategy;
import com.facebook.presto.spi.ErrorCode;
import com.google.common.collect.ImmutableList;

import javax.annotation.Nullable;

import java.util.List;
import java.util.Optional;

import static com.facebook.presto.execution.ExecutionFailureInfo.toStackTraceElement;
import static com.facebook.presto.spark.PrestoSparkSessionProperties.isRetryOnOutOfMemoryBroadcastJoinEnabled;
import static com.facebook.presto.spark.classloader_interface.RetryExecutionStrategy.DISABLE_BROADCAST_JOIN;
import static com.facebook.presto.spi.StandardErrorCode.EXCEEDED_LOCAL_BROADCAST_JOIN_MEMORY_LIMIT;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

public class PrestoSparkFailureUtils
{
private PrestoSparkFailureUtils() {}

public static PrestoSparkFailure toPrestoSparkFailure(Session session, ExecutionFailureInfo executionFailureInfo)
{
requireNonNull(executionFailureInfo, "executionFailureInfo is null");
PrestoSparkFailure prestoSparkFailure = toPrestoSparkFailure(executionFailureInfo);
checkState(prestoSparkFailure != null);

Optional<RetryExecutionStrategy> retryExecutionStrategy = getRetryExecutionStrategy(session, executionFailureInfo.getErrorCode(), executionFailureInfo.getMessage());
return new PrestoSparkFailure(
prestoSparkFailure.getMessage(),
prestoSparkFailure.getCause(),
prestoSparkFailure.getType(),
prestoSparkFailure.getErrorCode(),
retryExecutionStrategy);
}

@Nullable
private static PrestoSparkFailure toPrestoSparkFailure(ExecutionFailureInfo executionFailureInfo)
{
if (executionFailureInfo == null) {
return null;
}

PrestoSparkFailure prestoSparkFailure = new PrestoSparkFailure(
executionFailureInfo.getMessage(),
toPrestoSparkFailure(executionFailureInfo.getCause()),
executionFailureInfo.getType(),
executionFailureInfo.getErrorCode() == null ? "" : executionFailureInfo.getErrorCode().getName(),
Optional.empty());

for (ExecutionFailureInfo suppressed : executionFailureInfo.getSuppressed()) {
prestoSparkFailure.addSuppressed(requireNonNull(toPrestoSparkFailure(suppressed), "suppressed failure is null"));
}
ImmutableList.Builder<StackTraceElement> stackTraceBuilder = ImmutableList.builder();
for (String stack : executionFailureInfo.getStack()) {
stackTraceBuilder.add(toStackTraceElement(stack));
}
List<StackTraceElement> stackTrace = stackTraceBuilder.build();
prestoSparkFailure.setStackTrace(stackTrace.toArray(new StackTraceElement[stackTrace.size()]));
return prestoSparkFailure;
}

private static Optional<RetryExecutionStrategy> getRetryExecutionStrategy(Session session, ErrorCode errorCode, String message)
{
if (errorCode == null || message == null) {
return Optional.empty();
}

if (isRetryOnOutOfMemoryBroadcastJoinEnabled(session) && errorCode == EXCEEDED_LOCAL_BROADCAST_JOIN_MEMORY_LIMIT.toErrorCode()) {
return Optional.of(DISABLE_BROADCAST_JOIN);
}

return Optional.empty();
}
}
@@ -47,8 +47,10 @@
import com.facebook.presto.spark.classloader_interface.IPrestoSparkQueryExecutionFactory;
import com.facebook.presto.spark.classloader_interface.IPrestoSparkTaskExecutorFactory;
import com.facebook.presto.spark.classloader_interface.PrestoSparkConfInitializer;
import com.facebook.presto.spark.classloader_interface.PrestoSparkFailure;
import com.facebook.presto.spark.classloader_interface.PrestoSparkSession;
import com.facebook.presto.spark.classloader_interface.PrestoSparkTaskExecutorFactoryProvider;
import com.facebook.presto.spark.classloader_interface.RetryExecutionStrategy;
import com.facebook.presto.spi.Plugin;
import com.facebook.presto.spi.eventlistener.EventListener;
import com.facebook.presto.spi.function.FunctionImplementationType;
@@ -301,12 +303,12 @@ public PrestoSparkQueryRunner(String defaultCatalog, Map<String, String> additio
pluginManager.installPlugin(new HivePlugin("hive", Optional.of(metastore)));

Map<String, String> properties = ImmutableMap.<String, String>builder()
.put("hive.experimental-optimized-partition-update-serialization-enabled", "true")
.put("hive.allow-drop-table", "true")
.put("hive.allow-rename-table", "true")
.put("hive.allow-rename-column", "true")
.put("hive.allow-add-column", "true")
.put("hive.allow-drop-column", "true").build();
.put("hive.experimental-optimized-partition-update-serialization-enabled", "true")
.put("hive.allow-drop-table", "true")
.put("hive.allow-rename-table", "true")
.put("hive.allow-rename-column", "true")
.put("hive.allow-add-column", "true")
.put("hive.allow-drop-column", "true").build();

connectorManager.createConnection("hive", "hive", properties);

@@ -424,9 +426,28 @@ public MaterializedResult execute(String sql)
public MaterializedResult execute(Session session, String sql)
{
IPrestoSparkQueryExecutionFactory executionFactory = prestoSparkService.getQueryExecutionFactory();
try {
return execute(executionFactory, sparkContext, session, sql, Optional.empty());
}
catch (PrestoSparkFailure failure) {
if (failure.getRetryExecutionStrategy().isPresent()) {
return execute(executionFactory, sparkContext, session, sql, failure.getRetryExecutionStrategy());
}

throw failure;
}
}

private MaterializedResult execute(
IPrestoSparkQueryExecutionFactory executionFactory,
SparkContext sparkContext,
Session session,
String sql,
Optional<RetryExecutionStrategy> retryExecutionStrategy)
{
IPrestoSparkQueryExecution execution = executionFactory.create(
sparkContext,
createSessionInfo(session),
createSessionInfo(session, retryExecutionStrategy),
Optional.of(sql),
Optional.empty(),
Optional.empty(),
@@ -435,7 +456,9 @@ public MaterializedResult execute(Session session, String sql)
new TestingPrestoSparkTaskExecutorFactoryProvider(instanceId),
Optional.empty(),
Optional.empty());

List<List<Object>> results = execution.execute();

List<MaterializedRow> rows = results.stream()
.map(result -> new MaterializedRow(DEFAULT_PRECISION, result))
.collect(toImmutableList());
@@ -447,13 +470,13 @@
}
else {
return new MaterializedResult(
rows,
p.getOutputTypes(),
ImmutableMap.of(),
ImmutableSet.of(),
p.getUpdateType(),
OptionalLong.of((Long) getOnlyElement(getOnlyElement(rows).getFields())),
ImmutableList.of());
rows,
p.getOutputTypes(),
ImmutableMap.of(),
ImmutableSet.of(),
p.getUpdateType(),
OptionalLong.of((Long) getOnlyElement(getOnlyElement(rows).getFields())),
ImmutableList.of());
}
}
else {
@@ -468,7 +491,7 @@ public MaterializedResult execute(Session session, String sql)
}
}

private static PrestoSparkSession createSessionInfo(Session session)
private static PrestoSparkSession createSessionInfo(Session session, Optional<RetryExecutionStrategy> retryExecutionStrategy)
{
ImmutableMap.Builder<String, Map<String, String>> catalogSessionProperties = ImmutableMap.builder();
catalogSessionProperties.putAll(session.getConnectorProperties().entrySet().stream()
@@ -488,7 +511,8 @@ private static PrestoSparkSession createSessionInfo(Session session)
Optional.empty(),
session.getSystemProperties(),
catalogSessionProperties.build(),
session.getTraceToken());
session.getTraceToken(),
retryExecutionStrategy);
}

@Override