diff --git a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala
index 83bb093a3..0511a082f 100644
--- a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala
+++ b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/config/location/DataLocation.scala
@@ -25,6 +25,7 @@ import scala.collection.JavaConverters._
   new JsonSubTypes.Type(value = classOf[PathList], name = "pathlist"),
   new JsonSubTypes.Type(value = classOf[Jdbc], name = "jdbc"),
   new JsonSubTypes.Type(value = classOf[GenericLocation], name = "generic"),
+  new JsonSubTypes.Type(value = classOf[SparkSqlLocation], name = "sparksql"),
   new JsonSubTypes.Type(value = classOf[Snowflake], name = "snowflake"),
 ))
 trait DataLocation {
diff --git a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/config/location/SparkSqlLocation.scala b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/config/location/SparkSqlLocation.scala
new file mode 100644
index 000000000..134ade5f6
--- /dev/null
+++ b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/config/location/SparkSqlLocation.scala
@@ -0,0 +1,69 @@
+package com.linkedin.feathr.offline.config.location
+
+import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize
+import com.linkedin.feathr.common.Header
+import org.apache.spark.sql.{DataFrame, SparkSession}
+
+@CaseClassDeserialize()
+case class SparkSqlLocation(sql: Option[String] = None, table: Option[String] = None) extends DataLocation {
+  /**
+   * Backward compatibility
+   * Much existing code expects a simple path
+   *
+   * @return the `path` or `url` of the data source
+   *
+   * WARN: This method is deprecated; use match/case on InputLocation
+   * and get `path` from `SimplePath` only
+   */
+  override def getPath: String = sql.getOrElse(table.getOrElse(""))
+
+  /**
+   * Backward compatibility
+   *
+   * @return the `path` or `url` of the data source, wrapped in a List
+   *
+   * WARN: This method is deprecated; use match/case on InputLocation
+   * and get `paths` from `PathList` only
+   */
+  override def getPathList: List[String] = List(getPath)
+
+  /**
+   * Load a DataFrame from the Spark session
+   *
+   * @param ss SparkSession
+   * @return the loaded DataFrame
+   */
+  override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String]): DataFrame = {
+    sql match {
+      case Some(sql) => {
+        ss.sql(sql)
+      }
+      case None => table match {
+        case Some(table) => {
+          ss.sqlContext.table(table)
+        }
+        case None => {
+          throw new IllegalArgumentException("Either `sql` or `table` must be specified")
+        }
+      }
+    }
+  }
+
+  /**
+   * Write a DataFrame to the location
+   *
+   * NOTE: SparkSqlLocation doesn't support writing; it cannot be used as a data sink
+   *
+   * @param ss SparkSession
+   * @param df DataFrame to write
+   */
+  override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = throw new UnsupportedOperationException("SparkSqlLocation cannot be used as a data sink")
+
+  /**
+   * Tell if this location is file based
+   *
+   * @return boolean
+   */
+  override def isFileBasedLocation(): Boolean = false
+}
+
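For orientation (not part of the patch): the `sparksql` subtype registered above deserializes a location block carrying either a `sql` or a `table` field, matching the two optional fields on `SparkSqlLocation`. Below is a minimal Python sketch of that payload, assuming the discriminator property is `type` as emitted by the client-side `to_dict()` later in this diff; `my_table` is a placeholder name.

```python
import json

# Hypothetical payloads for the new "sparksql" location subtype.
# Exactly one of `sql` / `table` is expected; `my_table` is a placeholder name.
query_location = {"type": "sparksql", "sql": "SELECT * FROM my_table"}
table_location = {"type": "sparksql", "table": "my_table"}

# The Python client (see SparkSqlSource.to_argument below) ships this as a one-line JSON string.
print(json.dumps(query_location))
print(json.dumps(table_location))
```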
diff --git a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala
index 181feefff..4860f93ba 100644
--- a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala
+++ b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/source/accessor/NonTimeBasedDataSourceAccessor.scala
@@ -1,6 +1,6 @@
 package com.linkedin.feathr.offline.source.accessor
 
-import com.linkedin.feathr.offline.config.location.{GenericLocation, Jdbc, PathList, SimplePath, Snowflake}
+import com.linkedin.feathr.offline.config.location.{GenericLocation, Jdbc, PathList, SimplePath, Snowflake, SparkSqlLocation}
 import com.linkedin.feathr.offline.source.DataSource
 import com.linkedin.feathr.offline.source.dataloader.{CaseInsensitiveGenericRecordWrapper, DataLoaderFactory}
 import com.linkedin.feathr.offline.testfwk.TestFwkUtils
@@ -35,6 +35,7 @@ private[offline] class NonTimeBasedDataSourceAccessor(
       case PathList(paths) => paths.map(fileLoaderFactory.create(_).loadDataFrame()).reduce((x, y) => x.fuzzyUnion(y))
       case Jdbc(_, _, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate())
       case GenericLocation(_, _) => source.location.loadDf(SparkSession.builder().getOrCreate())
+      case SparkSqlLocation(_, _) => source.location.loadDf(SparkSession.builder().getOrCreate())
       case Snowflake(_, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate())
       case _ => fileLoaderFactory.createFromLocation(source.location).loadDataFrame()
     }
diff --git a/feathr-impl/src/test/scala/com/linkedin/feathr/offline/config/location/TestSparkSqlLocation.scala b/feathr-impl/src/test/scala/com/linkedin/feathr/offline/config/location/TestSparkSqlLocation.scala
new file mode 100644
index 000000000..35e1d876f
--- /dev/null
+++ b/feathr-impl/src/test/scala/com/linkedin/feathr/offline/config/location/TestSparkSqlLocation.scala
@@ -0,0 +1,45 @@
+package com.linkedin.feathr.offline.config.location
+
+import com.linkedin.feathr.offline.TestFeathr
+import com.linkedin.feathr.offline.source.accessor.DataSourceAccessor
+import com.linkedin.feathr.offline.source.{DataSource, SourceFormatType}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
+import org.testng.Assert.assertEquals
+import org.testng.annotations.{BeforeClass, Test}
+
+import scala.collection.JavaConversions._
+
+class TestSparkSqlLocation extends TestFeathr {
+  @BeforeClass
+  override def setup(): Unit = {
+    ss = TestFeathr.getOrCreateSparkSessionWithHive
+    super.setup()
+  }
+
+  @Test(description = "It should read a Spark SQL query as the DataSource")
+  def testCreateWithFixedPath(): Unit = {
+    val schema = StructType(Array(
+      StructField("language", StringType, true),
+      StructField("users", IntegerType, true)
+    ))
+    val rowData = Seq(Row("Java", 20000),
+      Row("Python", 100000),
+      Row("Scala", 3000))
+    val tempDf = ss.createDataFrame(rowData, schema)
+    tempDf.createTempView("test_spark_sql_table")
+
+    val loc = SparkSqlLocation(sql = Some("select * from test_spark_sql_table order by users asc"))
+    val dataSource = DataSource(loc, SourceFormatType.FIXED_PATH)
+    val accessor = DataSourceAccessor(ss, dataSource, None, None, failOnMissingPartition = false, dataPathHandlers = List())
+    val df = accessor.get()
+    val expectedRows = Array(
+      Row("Scala", 3000),
+      Row("Java", 20000),
+      Row("Python", 100000)
+    )
+    assertEquals(df.collect(), expectedRows)
+
+    ss.sqlContext.dropTempTable("test_spark_sql_table")
+  }
+}
diff --git a/feathr_project/feathr/definition/source.py b/feathr_project/feathr/definition/source.py
index 676a5cb76..1984c8fe5 100644
--- a/feathr_project/feathr/definition/source.py
+++ b/feathr_project/feathr/definition/source.py
@@ -365,6 +365,74 @@ def to_feature_config(self) -> str:
     def to_argument(self):
         raise TypeError("KafKaSource cannot be used as observation source")
 
+class SparkSqlSource(Source):
+    def __init__(self, name: str, sql: Optional[str] = None, table: Optional[str] = None, preprocessing: Optional[Callable] = None, event_timestamp_column: Optional[str] = None, timestamp_format: Optional[str] = "epoch", registry_tags: Optional[Dict[str, str]] = None) -> None:
+        """ SparkSqlSource can use either a SQL query or a table name as the source for a Feathr job.
+        name: name of the source
+        sql: SQL query to use as the source; either sql or table must be specified
+        table: table name to use as the source; either sql or table must be specified
+        preprocessing (Optional[Callable]): A preprocessing python function that transforms the source data for further feature transformation.
+        event_timestamp_column (Optional[str]): The timestamp field of your record, since sliding window aggregation features assume each record in the source data has a timestamp column.
+        timestamp_format (Optional[str], optional): The format of the timestamp field. Defaults to "epoch". Possible values are:
+            - `epoch` (seconds since epoch), for example `1647737463`
+            - `epoch_millis` (milliseconds since epoch), for example `1647737517761`
+            - Any date formats supported by [SimpleDateFormat](https://docs.oracle.com/javase/8/docs/api/java/text/SimpleDateFormat.html).
+        registry_tags: A dict of (str, str) that you can pass to the feature registry for better organization. For example, you can use {"deprecated": "true"} to indicate this source is deprecated, etc.
+        """
+        super().__init__(name, event_timestamp_column,
+                         timestamp_format, registry_tags=registry_tags)
+        self.source_type = 'sparksql'
+        if sql is None and table is None:
+            raise ValueError("Either `sql` or `table` must be specified")
+        if sql is not None and table is not None:
+            raise ValueError("Only one of `sql` or `table` can be specified")
+        if sql is not None:
+            self.sql = sql
+        if table is not None:
+            self.table = table
+        self.preprocessing = preprocessing
+
+    def to_feature_config(self) -> str:
+        tm = Template("""
+            {{source.name}}: {
+                location: {
+                    type: "sparksql"
+                    {% if source.sql is defined %}
+                    sql: "{{source.sql}}"
+                    {% elif source.table is defined %}
+                    table: "{{source.table}}"
+                    {% endif %}
+                }
+                {% if source.event_timestamp_column is defined %}
+                timeWindowParameters: {
+                    timestampColumn: "{{source.event_timestamp_column}}"
+                    timestampColumnFormat: "{{source.timestamp_format}}"
+                }
+                {% endif %}
+            }
+        """)
+        msg = tm.render(source=self)
+        return msg
+
+    def get_required_properties(self):
+        return []
+
+    def to_dict(self) -> Dict[str, str]:
+        ret = self.options.copy()
+        ret["type"] = "sparksql"
+        if getattr(self, "sql", None):
+            ret["sql"] = self.sql
+        elif getattr(self, "table", None):
+            ret["table"] = self.table
+        return ret
+
+    def to_argument(self):
+        """
+        One-line JSON string, used by job submitter
+        """
+        return json.dumps(self.to_dict())
+
+
 class GenericSource(Source):
     """
     This class is corresponding to 'GenericLocation' in Feathr core, but only be used as Source.
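To make the new client-side API concrete, here is a minimal usage sketch. It mirrors the `_sql_query_source()` / `_sql_table_source()` helpers in the test added below; the source names are made up, and `green_tripdata_2020_04_with_index` is the predefined Spark table that the CI test assumes.

```python
from feathr.definition.source import SparkSqlSource

# Query-backed source: the SQL is run against the job's Spark session (see SparkSqlLocation.loadDf).
query_source = SparkSqlSource(
    name="nycTaxiSqlSource",
    sql="SELECT * FROM green_tripdata_2020_04_with_index",
    event_timestamp_column="lpep_dropoff_datetime",
    timestamp_format="yyyy-MM-dd HH:mm:ss")

# Table-backed source: exactly one of `sql` / `table` may be given.
table_source = SparkSqlSource(
    name="nycTaxiTableSource",
    table="green_tripdata_2020_04_with_index")

# Serialized for the job submitter as a one-line JSON string,
# e.g. {"type": "sparksql", "sql": "SELECT * FROM ..."}
print(query_source.to_argument())
```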
diff --git a/feathr_project/test/test_spark_sql_source.py b/feathr_project/test/test_spark_sql_source.py
new file mode 100644
index 000000000..0189302b9
--- /dev/null
+++ b/feathr_project/test/test_spark_sql_source.py
@@ -0,0 +1,159 @@
+import os
+from datetime import datetime, timedelta
+from pathlib import Path
+from feathr import (BOOLEAN, FLOAT, INPUT_CONTEXT, INT32,
+                    DerivedFeature, Feature, FeatureAnchor,
+                    TypedKey, ValueType, WindowAggTransformation)
+
+import pytest
+from feathr import FeathrClient
+from feathr import FeatureQuery
+from feathr import ObservationSettings
+from feathr import TypedKey
+from feathr import ValueType
+from feathr.definition.materialization_settings import BackfillTime, MaterializationSettings
+from feathr.definition.sink import HdfsSink
+from feathr.utils.job_utils import get_result_df
+from test_utils.constants import Constants
+from feathr.definition.source import SparkSqlSource
+
+
+@pytest.mark.skipif(os.environ.get('SPARK_CONFIG__SPARK_CLUSTER') != "databricks",
+                    reason="this test uses a predefined Databricks table.")
+def test_feathr_spark_sql_query_source():
+    test_workspace_dir = Path(
+        __file__).parent.resolve() / "test_user_workspace"
+    config_path = os.path.join(test_workspace_dir, "feathr_config.yaml")
+
+    _get_offline_features(config_path, _sql_query_source())
+    _get_offline_features(config_path, _sql_table_source())
+    _materialize_to_offline(config_path, _sql_query_source())
+
+
+def _get_offline_features(config_path: str, sql_source: SparkSqlSource):
+    client: FeathrClient = _spark_sql_test_setup(config_path, sql_source)
+    location_id = TypedKey(key_column="DOLocationID",
+                           key_column_type=ValueType.INT32,
+                           description="location id in NYC",
+                           full_name="nyc_taxi.location_id")
+
+    feature_query = FeatureQuery(
+        feature_list=["f_location_avg_fare"], key=location_id)
+    settings = ObservationSettings(
+        observation_path="wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv",
+        event_timestamp_column="lpep_dropoff_datetime",
+        timestamp_format="yyyy-MM-dd HH:mm:ss")
+
+    now = datetime.now()
+    output_path = ''.join(['dbfs:/feathrazure_cijob_sparksql',
+                           '_', str(now.minute), '_', str(now.second), ".avro"])
+
+    client.get_offline_features(observation_settings=settings,
+                                feature_query=feature_query,
+                                output_path=output_path)
+
+    # assuming the job can successfully run; otherwise it will throw an exception
+    client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS)
+    return
+
+
+def _materialize_to_offline(config_path: str, sql_source: SparkSqlSource):
+    client: FeathrClient = _spark_sql_test_setup(config_path, sql_source)
+    backfill_time = BackfillTime(start=datetime(
+        2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1))
+
+    now = datetime.now()
+    if client.spark_runtime == 'databricks':
+        output_path = ''.join(['dbfs:/feathrazure_cijob_materialize_offline_sparksql',
+                               '_', str(now.minute), '_', str(now.second), ""])
+    else:
+        output_path = ''.join(['abfss://feathrazuretest3fs@feathrazuretest3storage.dfs.core.windows.net/demo_data/feathrazure_cijob_materialize_offline_sparksql',
+                               '_', str(now.minute), '_', str(now.second), ""])
+    offline_sink = HdfsSink(output_path=output_path)
+    settings = MaterializationSettings("nycTaxiTable",
+                                       sinks=[offline_sink],
+                                       feature_names=[
+                                           "f_location_avg_fare", "f_location_max_fare"],
+                                       backfill_time=backfill_time)
+    client.materialize_features(settings)
+    # assuming the job can successfully run; otherwise it will throw an exception
+    client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS)
+
+    # download result and just assert the returned result is not empty
+    # by default, it will write to a folder appended with date
+    res_df = get_result_df(
+        client, "avro", output_path + "/df0/daily/2020/05/20")
+    assert res_df.shape[0] > 0
+
+
+def _spark_sql_test_setup(config_path: str, sql_source: SparkSqlSource):
+    client = FeathrClient(config_path=config_path)
+
+    f_trip_distance = Feature(name="f_trip_distance",
+                              feature_type=FLOAT, transform="trip_distance")
+    f_trip_time_duration = Feature(name="f_trip_time_duration",
+                                   feature_type=INT32,
+                                   transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60")
+
+    features = [
+        f_trip_distance,
+        f_trip_time_duration,
+        Feature(name="f_is_long_trip_distance",
+                feature_type=BOOLEAN,
+                transform="cast_float(trip_distance)>30"),
+        Feature(name="f_day_of_week",
+                feature_type=INT32,
+                transform="dayofweek(lpep_dropoff_datetime)"),
+    ]
+
+    request_anchor = FeatureAnchor(name="request_features",
+                                   source=INPUT_CONTEXT,
+                                   features=features)
+
+    f_trip_time_distance = DerivedFeature(name="f_trip_time_distance",
+                                          feature_type=FLOAT,
+                                          input_features=[
+                                              f_trip_distance, f_trip_time_duration],
+                                          transform="f_trip_distance * f_trip_time_duration")
+
+    f_trip_time_rounded = DerivedFeature(name="f_trip_time_rounded",
+                                         feature_type=INT32,
+                                         input_features=[f_trip_time_duration],
+                                         transform="f_trip_time_duration % 10")
+
+    location_id = TypedKey(key_column="DOLocationID",
+                           key_column_type=ValueType.INT32,
+                           description="location id in NYC",
+                           full_name="nyc_taxi.location_id")
+    agg_features = [Feature(name="f_location_avg_fare",
+                            key=location_id,
+                            feature_type=FLOAT,
+                            transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)",
+                                                              agg_func="AVG",
+                                                              window="90d",
+                                                              filter="fare_amount > 0"
+                                                              )),
+                    Feature(name="f_location_max_fare",
+                            key=location_id,
+                            feature_type=FLOAT,
+                            transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)",
+                                                              agg_func="MAX",
+                                                              window="90d"))
+                    ]
+
+    agg_anchor = FeatureAnchor(name="aggregationFeatures",
+                               source=sql_source,
+                               features=agg_features)
+
+    client.build_features(anchor_list=[agg_anchor, request_anchor], derived_feature_list=[
+        f_trip_time_distance, f_trip_time_rounded])
+
+    return client
+
+
+def _sql_query_source():
+    return SparkSqlSource(name="sparkSqlQuerySource", sql="SELECT * FROM green_tripdata_2020_04_with_index", event_timestamp_column="lpep_dropoff_datetime", timestamp_format="yyyy-MM-dd HH:mm:ss")
+
+
+def _sql_table_source():
+    return SparkSqlSource(name="sparkSqlTableSource", table="green_tripdata_2020_04_with_index", event_timestamp_column="lpep_dropoff_datetime", timestamp_format="yyyy-MM-dd HH:mm:ss")
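The test above assumes a predefined Databricks table named `green_tripdata_2020_04_with_index`. For local experimentation, here is a rough setup sketch (not part of the patch; the exact schema and index column of that table are assumptions) that registers the public sample CSV under the same name before a Feathr job reads it through `SparkSqlSource`/`SparkSqlLocation`:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Read the public NYC taxi sample used elsewhere in the tests
# (requires the hadoop-azure/wasbs connector on the classpath).
df = (spark.read
      .option("header", "true")
      .option("inferSchema", "true")
      .csv("wasbs://public@azurefeathrstorage.blob.core.windows.net/sample_data/green_tripdata_2020-04.csv"))

# A temp view is visible only within this Spark session; on a shared Databricks
# cluster you would persist it instead, e.g. df.write.saveAsTable(...).
df.createOrReplaceTempView("green_tripdata_2020_04_with_index")
```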