* Spark SQL source
* Add unit test
* Clean up after test
* Typo
* Fix conflict
* Add doc comments
* add spark sql source test cases (Signed-off-by: Yuqing Wei <[email protected]>)
* add spark sql source test cases (Signed-off-by: Yuqing Wei <[email protected]>)
* Ignore sql if it is None (Signed-off-by: Yuqing Wei <[email protected]>)

Co-authored-by: Yuqing Wei <[email protected]>
Commit 308321d, 1 parent 01ea96d. Showing 6 changed files with 344 additions and 1 deletion.
...hr-impl/src/main/scala/com/linkedin/feathr/offline/config/location/SparkSqlLocation.scala (69 additions, 0 deletions)
@@ -0,0 +1,69 @@
package com.linkedin.feathr.offline.config.location

import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize
import com.linkedin.feathr.common.Header
import org.apache.spark.sql.{DataFrame, SparkSession}

@CaseClassDeserialize()
case class SparkSqlLocation(sql: Option[String] = None, table: Option[String] = None) extends DataLocation {
  /**
   * Backward compatibility.
   * Much existing code expects a simple path.
   *
   * @return the `path` or `url` of the data source
   *
   * WARN: This method is deprecated; you must use match/case on InputLocation
   * and get `path` from `SimplePath` only.
   */
  override def getPath: String = sql.getOrElse(table.getOrElse(""))

  /**
   * Backward compatibility.
   *
   * @return the `path` or `url` of the data source, wrapped in a List
   *
   * WARN: This method is deprecated; you must use match/case on InputLocation
   * and get `paths` from `PathList` only.
   */
  override def getPathList: List[String] = List(getPath)

  /**
   * Load a DataFrame from the Spark session.
   *
   * @param ss SparkSession
   * @return the DataFrame produced by the SQL query, or the named table
   */
  override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String]): DataFrame = {
    // `sql` takes precedence over `table`; at least one of them must be set.
    sql match {
      case Some(sql) => {
        ss.sql(sql)
      }
      case None => table match {
        case Some(table) => {
          ss.sqlContext.table(table)
        }
        case None => {
          throw new IllegalArgumentException("`sql` and `table` parameter should not both be empty")
        }
      }
    }
  }

  /**
   * Write a DataFrame to the location.
   *
   * NOTE: SparkSqlLocation doesn't support writing; it cannot be used as a data sink.
   *
   * @param ss SparkSession
   * @param df DataFrame to write
   */
  override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = ???

  /**
   * Tell if this location is file based.
   *
   * @return boolean
   */
  override def isFileBasedLocation(): Boolean = false
}
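For context, a minimal sketch of how the new location resolves its two parameters when loading: `sql` takes precedence and is run through `ss.sql(...)`, a bare `table` is resolved through `ss.sqlContext.table(...)`, and leaving both unset raises `IllegalArgumentException`. The session setup and names below are illustrative only, not part of this commit.

import com.linkedin.feathr.offline.config.location.SparkSqlLocation
import org.apache.spark.sql.{DataFrame, SparkSession}

object SparkSqlLocationSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative local session; any SparkSession that can see the referenced table works.
    val ss: SparkSession = SparkSession.builder()
      .master("local[*]")
      .appName("spark-sql-location-sketch")
      .getOrCreate()

    // `sql` takes precedence: the query is executed via ss.sql(...).
    val byQuery: DataFrame = SparkSqlLocation(sql = Some("SELECT 'Scala' AS language, 3000 AS users"))
      .loadDf(ss, Map.empty[String, String])
    byQuery.createOrReplaceTempView("languages")

    // With only `table` set, the table is resolved via ss.sqlContext.table(...).
    val byTable: DataFrame = SparkSqlLocation(table = Some("languages"))
      .loadDf(ss, Map.empty[String, String])
    byTable.show()

    // SparkSqlLocation() with neither `sql` nor `table` would throw IllegalArgumentException from loadDf.
    ss.stop()
  }
}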
...mpl/src/test/scala/com/linkedin/feathr/offline/config/location/TestSparkSqlLocation.scala (45 additions, 0 deletions)
@@ -0,0 +1,45 @@
package com.linkedin.feathr.offline.config.location

import com.linkedin.feathr.offline.TestFeathr
import com.linkedin.feathr.offline.source.accessor.DataSourceAccessor
import com.linkedin.feathr.offline.source.{DataSource, SourceFormatType}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.testng.Assert.assertEquals
import org.testng.annotations.{BeforeClass, Test}

import scala.collection.JavaConversions._

class TestSparkSqlLocation extends TestFeathr {
  @BeforeClass
  override def setup(): Unit = {
    ss = TestFeathr.getOrCreateSparkSessionWithHive
    super.setup()
  }

  @Test(description = "It should read a Spark SQL as the DataSource")
  def testCreateWithFixedPath(): Unit = {
    val schema = StructType(Array(
      StructField("language", StringType, true),
      StructField("users", IntegerType, true)
    ))
    val rowData = Seq(Row("Java", 20000),
      Row("Python", 100000),
      Row("Scala", 3000))
    val tempDf = ss.createDataFrame(rowData, schema)
    tempDf.createTempView("test_spark_sql_table")

    val loc = SparkSqlLocation(sql = Some("select * from test_spark_sql_table order by users asc"))
    val dataSource = DataSource(loc, SourceFormatType.FIXED_PATH)
    val accessor = DataSourceAccessor(ss, dataSource, None, None, failOnMissingPartition = false, dataPathHandlers = List())
    val df = accessor.get()
    val expectedRows = Array(
      Row("Scala", 3000),
      Row("Java", 20000),
      Row("Python", 100000)
    )
    assertEquals(df.collect(), expectedRows)

    ss.sqlContext.dropTempTable("test_spark_sql_table")
  }
}
@@ -0,0 +1,159 @@
import os
from datetime import datetime, timedelta
from pathlib import Path

from feathr import (BOOLEAN, FLOAT, INPUT_CONTEXT, INT32,
                    DerivedFeature, Feature, FeatureAnchor,
                    TypedKey, ValueType, WindowAggTransformation)

import pytest
from feathr import FeathrClient
from feathr import FeatureQuery
from feathr import ObservationSettings
from feathr import TypedKey
from feathr import ValueType
from feathr.definition.materialization_settings import BackfillTime, MaterializationSettings
from feathr.definition.sink import HdfsSink
from feathr.utils.job_utils import get_result_df
from test_utils.constants import Constants
from feathr.definition.source import SparkSqlSource


@pytest.mark.skipif(os.environ.get('SPARK_CONFIG__SPARK_CLUSTER') != "databricks",
                    reason="this test uses predefined Databricks table.")
def test_feathr_spark_sql_query_source():
    test_workspace_dir = Path(
        __file__).parent.resolve() / "test_user_workspace"
    config_path = os.path.join(test_workspace_dir, "feathr_config.yaml")

    _get_offline_features(config_path, _sql_query_source())
    _get_offline_features(config_path, _sql_table_source())
    _materialize_to_offline(config_path, _sql_query_source())


def _get_offline_features(config_path: str, sql_source: SparkSqlSource):
    client: FeathrClient = _spark_sql_test_setup(config_path, sql_source)
    location_id = TypedKey(key_column="DOLocationID",
                           key_column_type=ValueType.INT32,
                           description="location id in NYC",
                           full_name="nyc_taxi.location_id")

    feature_query = FeatureQuery(
        feature_list=["f_location_avg_fare"], key=location_id)
    settings = ObservationSettings(
        observation_path="wasbs://[email protected]/sample_data/green_tripdata_2020-04.csv",
        event_timestamp_column="lpep_dropoff_datetime",
        timestamp_format="yyyy-MM-dd HH:mm:ss")

    now = datetime.now()
    output_path = ''.join(['dbfs:/feathrazure_cijob_sparksql',
                           '_', str(now.minute), '_', str(now.second), ".avro"])

    client.get_offline_features(observation_settings=settings,
                                feature_query=feature_query,
                                output_path=output_path)

    # assuming the job can run successfully; otherwise it will throw an exception
    client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS)
    return


def _materialize_to_offline(config_path: str, sql_source: SparkSqlSource):
    client: FeathrClient = _spark_sql_test_setup(config_path, sql_source)
    backfill_time = BackfillTime(start=datetime(
        2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1))

    now = datetime.now()
    if client.spark_runtime == 'databricks':
        output_path = ''.join(['dbfs:/feathrazure_cijob_materialize_offline_sparksql',
                               '_', str(now.minute), '_', str(now.second), ""])
    else:
        output_path = ''.join(['abfss://[email protected]/demo_data/feathrazure_cijob_materialize_offline_sparksql',
                               '_', str(now.minute), '_', str(now.second), ""])
    offline_sink = HdfsSink(output_path=output_path)
    settings = MaterializationSettings("nycTaxiTable",
                                       sinks=[offline_sink],
                                       feature_names=[
                                           "f_location_avg_fare", "f_location_max_fare"],
                                       backfill_time=backfill_time)
    client.materialize_features(settings)
    # assuming the job can run successfully; otherwise it will throw an exception
    client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS)

    # download the result and just assert that it is not empty
    # by default, it will write to a folder appended with the date
    res_df = get_result_df(
        client, "avro", output_path + "/df0/daily/2020/05/20")
    assert res_df.shape[0] > 0


def _spark_sql_test_setup(config_path: str, sql_source: SparkSqlSource):
    client = FeathrClient(config_path=config_path)

    f_trip_distance = Feature(name="f_trip_distance",
                              feature_type=FLOAT, transform="trip_distance")
    f_trip_time_duration = Feature(name="f_trip_time_duration",
                                   feature_type=INT32,
                                   transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60")

    features = [
        f_trip_distance,
        f_trip_time_duration,
        Feature(name="f_is_long_trip_distance",
                feature_type=BOOLEAN,
                transform="cast_float(trip_distance)>30"),
        Feature(name="f_day_of_week",
                feature_type=INT32,
                transform="dayofweek(lpep_dropoff_datetime)"),
    ]

    request_anchor = FeatureAnchor(name="request_features",
                                   source=INPUT_CONTEXT,
                                   features=features)

    f_trip_time_distance = DerivedFeature(name="f_trip_time_distance",
                                          feature_type=FLOAT,
                                          input_features=[
                                              f_trip_distance, f_trip_time_duration],
                                          transform="f_trip_distance * f_trip_time_duration")

    f_trip_time_rounded = DerivedFeature(name="f_trip_time_rounded",
                                         feature_type=INT32,
                                         input_features=[f_trip_time_duration],
                                         transform="f_trip_time_duration % 10")

    location_id = TypedKey(key_column="DOLocationID",
                           key_column_type=ValueType.INT32,
                           description="location id in NYC",
                           full_name="nyc_taxi.location_id")
    agg_features = [Feature(name="f_location_avg_fare",
                            key=location_id,
                            feature_type=FLOAT,
                            transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)",
                                                              agg_func="AVG",
                                                              window="90d",
                                                              filter="fare_amount > 0")),
                    Feature(name="f_location_max_fare",
                            key=location_id,
                            feature_type=FLOAT,
                            transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)",
                                                              agg_func="MAX",
                                                              window="90d"))
                    ]

    # The aggregation features are anchored on the Spark SQL source under test.
    agg_anchor = FeatureAnchor(name="aggregationFeatures",
                               source=sql_source,
                               features=agg_features)

    client.build_features(anchor_list=[agg_anchor, request_anchor], derived_feature_list=[
                          f_trip_time_distance, f_trip_time_rounded])

    return client


# Spark SQL source defined by a query.
def _sql_query_source():
    return SparkSqlSource(name="sparkSqlQuerySource", sql="SELECT * FROM green_tripdata_2020_04_with_index", event_timestamp_column="lpep_dropoff_datetime", timestamp_format="yyyy-MM-dd HH:mm:ss")


# Spark SQL source defined by an existing table.
def _sql_table_source():
    return SparkSqlSource(name="sparkSqlTableSource", table="green_tripdata_2020_04_with_index", event_timestamp_column="lpep_dropoff_datetime", timestamp_format="yyyy-MM-dd HH:mm:ss")