Spark: Create base classes for migration to JUnit5 (#9129)
tomtongue authored Nov 24, 2023
1 parent c427a56 commit b20d30c
Showing 5 changed files with 535 additions and 62 deletions.
51 changes: 51 additions & 0 deletions spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/CatalogTestBase.java
@@ -0,0 +1,51 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.spark;

import java.util.Map;
import java.util.stream.Stream;
import org.junit.jupiter.params.provider.Arguments;

public abstract class CatalogTestBase extends TestBaseWithCatalog {

// these parameters are broken out to avoid changes that need to modify lots of test suites
public static Stream<Arguments> parameters() {
return Stream.of(
Arguments.of(
SparkCatalogConfig.HIVE.catalogName(),
SparkCatalogConfig.HIVE.implementation(),
SparkCatalogConfig.HIVE.properties()),
Arguments.of(
SparkCatalogConfig.HADOOP.catalogName(),
SparkCatalogConfig.HADOOP.implementation(),
SparkCatalogConfig.HADOOP.properties()),
Arguments.of(
SparkCatalogConfig.SPARK.catalogName(),
SparkCatalogConfig.SPARK.implementation(),
SparkCatalogConfig.SPARK.properties()));
}

public CatalogTestBase(SparkCatalogConfig config) {
super(config);
}

public CatalogTestBase(String catalogName, String implementation, Map<String, String> config) {
super(catalogName, implementation, config);
}
}
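
The Stream<Arguments> above is the JUnit 5 form of the catalog parameter matrix that the JUnit 4 suites supplied through the Parameterized runner. How subclasses actually receive these values is defined in TestBaseWithCatalog, which is not part of this excerpt; the sketch below is only a hypothetical illustration of how such a method source is consumed with @ParameterizedTest and @MethodSource (the class and method names are made up).

// Hypothetical sketch, not part of this commit.
import java.util.Map;
import java.util.stream.Stream;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

class CatalogParameterSketch {

  static Stream<Arguments> catalogs() {
    return CatalogTestBase.parameters();
  }

  @ParameterizedTest
  @MethodSource("catalogs")
  void runsOncePerCatalog(String catalogName, String implementation, Map<String, String> config) {
    // executed three times: once each for the Hive, Hadoop, and Spark session catalog configurations
    System.out.printf("catalog=%s impl=%s props=%s%n", catalogName, implementation, config);
  }
}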
spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/SparkTestHelperBase.java
@@ -22,7 +22,7 @@
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 import org.apache.spark.sql.Row;
-import org.junit.Assert;
+import org.assertj.core.api.Assertions;
 
 public class SparkTestHelperBase {
   protected static final Object ANY = new Object();
@@ -55,12 +55,13 @@ private Object[] toJava(Row row) {
 
   protected void assertEquals(
       String context, List<Object[]> expectedRows, List<Object[]> actualRows) {
-    Assert.assertEquals(
-        context + ": number of results should match", expectedRows.size(), actualRows.size());
+    Assertions.assertThat(actualRows)
+        .as(context + ": number of results should match")
+        .hasSameSizeAs(expectedRows);
     for (int row = 0; row < expectedRows.size(); row += 1) {
       Object[] expected = expectedRows.get(row);
       Object[] actual = actualRows.get(row);
-      Assert.assertEquals("Number of columns should match", expected.length, actual.length);
+      Assertions.assertThat(actual).as("Number of columns should match").hasSameSizeAs(expected);
       for (int col = 0; col < actualRows.get(row).length; col += 1) {
         String newContext = String.format("%s: row %d col %d", context, row + 1, col + 1);
         assertEquals(newContext, expected, actual);
@@ -69,19 +70,23 @@ protected void assertEquals(
   }
 
   protected void assertEquals(String context, Object[] expectedRow, Object[] actualRow) {
-    Assert.assertEquals("Number of columns should match", expectedRow.length, actualRow.length);
+    Assertions.assertThat(actualRow)
+        .as("Number of columns should match")
+        .hasSameSizeAs(expectedRow);
     for (int col = 0; col < actualRow.length; col += 1) {
       Object expectedValue = expectedRow[col];
       Object actualValue = actualRow[col];
       if (expectedValue != null && expectedValue.getClass().isArray()) {
         String newContext = String.format("%s (nested col %d)", context, col + 1);
         if (expectedValue instanceof byte[]) {
-          Assert.assertArrayEquals(newContext, (byte[]) expectedValue, (byte[]) actualValue);
+          Assertions.assertThat(actualValue).as(newContext).isEqualTo(expectedValue);
         } else {
           assertEquals(newContext, (Object[]) expectedValue, (Object[]) actualValue);
         }
       } else if (expectedValue != ANY) {
-        Assert.assertEquals(context + " contents should match", expectedValue, actualValue);
+        Assertions.assertThat(actualValue)
+            .as(context + " contents should match")
+            .isEqualTo(expectedValue);
       }
     }
   }
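
The substitution above is mechanical: each JUnit 4 Assert call becomes an AssertJ call, with the failure message moved into as(...) and length checks expressed as hasSameSizeAs. A minimal, self-contained sketch of the same idioms, using made-up values and assuming assertj-core is on the test classpath:

import static org.assertj.core.api.Assertions.assertThat;

class AssertjIdiomSketch {
  void example() {
    Object[] expected = {1, "a"};
    Object[] actual = {1, "a"};
    // Assert.assertEquals(message, expected.length, actual.length) becomes a size assertion
    assertThat(actual).as("Number of columns should match").hasSameSizeAs(expected);
    // Assert.assertEquals(message, expectedValue, actualValue) becomes isEqualTo with a description
    assertThat(actual[0]).as("contents should match").isEqualTo(expected[0]);
    // Assert.assertArrayEquals becomes isEqualTo; AssertJ compares array contents element by element
    assertThat(new byte[] {1, 2}).as("bytes").isEqualTo(new byte[] {1, 2});
  }
}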
287 changes: 287 additions & 0 deletions spark/v3.5/spark/src/test/java/org/apache/iceberg/spark/TestBase.java
@@ -0,0 +1,287 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.iceberg.spark;

import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.List;
import java.util.Map;
import java.util.TimeZone;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicReference;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.iceberg.CatalogUtil;
import org.apache.iceberg.ContentFile;
import org.apache.iceberg.catalog.Namespace;
import org.apache.iceberg.exceptions.AlreadyExistsException;
import org.apache.iceberg.hive.HiveCatalog;
import org.apache.iceberg.hive.TestHiveMetastore;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.analysis.NoSuchTableException;
import org.apache.spark.sql.execution.QueryExecution;
import org.apache.spark.sql.execution.SparkPlan;
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec;
import org.apache.spark.sql.internal.SQLConf;
import org.apache.spark.sql.util.QueryExecutionListener;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;

public abstract class TestBase extends SparkTestHelperBase {

protected static TestHiveMetastore metastore = null;
protected static HiveConf hiveConf = null;
protected static SparkSession spark = null;
protected static JavaSparkContext sparkContext = null;
protected static HiveCatalog catalog = null;

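// Shared fixtures for every test in the class: a local Hive metastore, a local[2] SparkSession
// wired to it, and a HiveCatalog. @BeforeAll is the JUnit 5 counterpart of JUnit 4's @BeforeClass.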
@BeforeAll
public static void startMetastoreAndSpark() {
TestBase.metastore = new TestHiveMetastore();
metastore.start();
TestBase.hiveConf = metastore.hiveConf();

TestBase.spark =
SparkSession.builder()
.master("local[2]")
.config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic")
.config("spark.hadoop." + METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname))
.config("spark.sql.legacy.respectNullabilityInTextDatasetConversion", "true")
.enableHiveSupport()
.getOrCreate();

TestBase.sparkContext = JavaSparkContext.fromSparkContext(spark.sparkContext());

TestBase.catalog =
(HiveCatalog)
CatalogUtil.loadCatalog(
HiveCatalog.class.getName(), "hive", ImmutableMap.of(), hiveConf);

try {
catalog.createNamespace(Namespace.of("default"));
} catch (AlreadyExistsException ignored) {
// the default namespace already exists. ignore the create error
}
}

@AfterAll
public static void stopMetastoreAndSpark() throws Exception {
TestBase.catalog = null;
if (metastore != null) {
metastore.stop();
TestBase.metastore = null;
}
if (spark != null) {
spark.stop();
TestBase.spark = null;
TestBase.sparkContext = null;
}
}

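// Busy-waits until the wall clock is strictly past the given timestamp, then returns the new current time.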
protected long waitUntilAfter(long timestampMillis) {
long current = System.currentTimeMillis();
while (current <= timestampMillis) {
current = System.currentTimeMillis();
}
return current;
}

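// Runs a formatted SQL statement on the shared session and returns the rows converted to plain Java values.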
protected List<Object[]> sql(String query, Object... args) {
List<Row> rows = spark.sql(String.format(query, args)).collectAsList();
if (rows.size() < 1) {
return ImmutableList.of();
}

return rowsToJava(rows);
}

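// Runs a query expected to produce exactly one row with one column and returns that single value.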
protected Object scalarSql(String query, Object... args) {
List<Object[]> rows = sql(query, args);
Assertions.assertThat(rows.size()).as("Scalar SQL should return one row").isEqualTo(1);
Object[] row = Iterables.getOnlyElement(rows);
Assertions.assertThat(row.length).as("Scalar SQL should return one value").isEqualTo(1);
return row[0];
}

protected Object[] row(Object... values) {
return values;
}

protected static String dbPath(String dbName) {
return metastore.getDatabasePath(dbName);
}

protected void withUnavailableFiles(Iterable<? extends ContentFile<?>> files, Action action) {
Iterable<String> fileLocations = Iterables.transform(files, file -> file.path().toString());
withUnavailableLocations(fileLocations, action);
}

private void move(String location, String newLocation) {
Path path = Paths.get(URI.create(location));
Path tempPath = Paths.get(URI.create(newLocation));

try {
Files.move(path, tempPath);
} catch (IOException e) {
throw new UncheckedIOException("Failed to move: " + location, e);
}
}

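// Simulates missing data by renaming each location to '<location>_temp' for the duration of the action,
// then restoring the original names in a finally block.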
protected void withUnavailableLocations(Iterable<String> locations, Action action) {
for (String location : locations) {
move(location, location + "_temp");
}

try {
action.invoke();
} finally {
for (String location : locations) {
move(location + "_temp", location);
}
}
}

protected void withDefaultTimeZone(String zoneId, Action action) {
TimeZone currentZone = TimeZone.getDefault();
try {
TimeZone.setDefault(TimeZone.getTimeZone(zoneId));
action.invoke();
} finally {
TimeZone.setDefault(currentZone);
}
}

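// Applies the given Spark SQL conf overrides only for the duration of the action, then restores the prior
// values (or unsets keys that were not previously set); static confs are rejected.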
protected void withSQLConf(Map<String, String> conf, Action action) {
SQLConf sqlConf = SQLConf.get();

Map<String, String> currentConfValues = Maps.newHashMap();
conf.keySet()
.forEach(
confKey -> {
if (sqlConf.contains(confKey)) {
String currentConfValue = sqlConf.getConfString(confKey);
currentConfValues.put(confKey, currentConfValue);
}
});

conf.forEach(
(confKey, confValue) -> {
if (SQLConf.isStaticConfigKey(confKey)) {
throw new RuntimeException("Cannot modify the value of a static config: " + confKey);
}
sqlConf.setConfString(confKey, confValue);
});

try {
action.invoke();
} finally {
conf.forEach(
(confKey, confValue) -> {
if (currentConfValues.containsKey(confKey)) {
sqlConf.setConfString(confKey, currentConfValues.get(confKey));
} else {
sqlConf.unsetConf(confKey);
}
});
}
}

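// Parses the given JSON records into a DataFrame using the supplied DDL schema string.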
protected Dataset<Row> jsonToDF(String schema, String... records) {
Dataset<String> jsonDF = spark.createDataset(ImmutableList.copyOf(records), Encoders.STRING());
return spark.read().schema(schema).json(jsonDF);
}

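// Appends the given JSON records to an existing table, reusing the table's current schema.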
protected void append(String table, String... jsonRecords) {
try {
String schema = spark.table(table).schema().toDDL();
Dataset<Row> df = jsonToDF(schema, jsonRecords);
df.coalesce(1).writeTo(table).append();
} catch (NoSuchTableException e) {
throw new RuntimeException("Failed to write data", e);
}
}

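// Renders a property map as comma-separated 'key' 'value' pairs for embedding in SQL text.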
protected String tablePropsAsString(Map<String, String> tableProps) {
StringBuilder stringBuilder = new StringBuilder();

for (Map.Entry<String, String> property : tableProps.entrySet()) {
if (stringBuilder.length() > 0) {
stringBuilder.append(", ");
}
stringBuilder.append(String.format("'%s' '%s'", property.getKey(), property.getValue()));
}

return stringBuilder.toString();
}

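// Runs a query while a QueryExecutionListener captures the executed SparkPlan; adaptive (AQE) plans are
// unwrapped to their final physical plan.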
protected SparkPlan executeAndKeepPlan(String query, Object... args) {
return executeAndKeepPlan(() -> sql(query, args));
}

protected SparkPlan executeAndKeepPlan(Action action) {
AtomicReference<SparkPlan> executedPlanRef = new AtomicReference<>();

QueryExecutionListener listener =
new QueryExecutionListener() {
@Override
public void onSuccess(String funcName, QueryExecution qe, long durationNs) {
executedPlanRef.set(qe.executedPlan());
}

@Override
public void onFailure(String funcName, QueryExecution qe, Exception exception) {}
};

spark.listenerManager().register(listener);

action.invoke();

try {
spark.sparkContext().listenerBus().waitUntilEmpty();
} catch (TimeoutException e) {
throw new RuntimeException("Timeout while waiting for processing events", e);
}

SparkPlan executedPlan = executedPlanRef.get();
if (executedPlan instanceof AdaptiveSparkPlanExec) {
return ((AdaptiveSparkPlanExec) executedPlan).executedPlan();
} else {
return executedPlan;
}
}

@FunctionalInterface
protected interface Action {
void invoke();
}
}
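
A hypothetical subclass, not part of this commit, sketching how these helpers combine in a JUnit 5 test; the class name, conf key, and query below are illustrative only.

package org.apache.iceberg.spark;

import static org.assertj.core.api.Assertions.assertThat;

import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.junit.jupiter.api.Test;

public class ExampleHelperUsageTest extends TestBase {

  @Test
  public void runsScalarQueryUnderTemporaryConf() {
    // the metastore and SparkSession are already running, started by @BeforeAll in TestBase
    withSQLConf(
        ImmutableMap.of("spark.sql.shuffle.partitions", "1"),
        () -> assertThat(scalarSql("SELECT %d + %d", 1, 2)).isEqualTo(3));
  }
}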