Windoze/spark sql source (#839)
* Spark SQL source

* Add unit test

* Clean up after test

* Typo

* Fix conflict

* Add doc comments

* add spark sql source test cases

Signed-off-by: Yuqing Wei <[email protected]>

* add spark sql source test cases

Signed-off-by: Yuqing Wei <[email protected]>

* Ignore sql if it is None

Signed-off-by: Yuqing Wei <[email protected]>
Co-authored-by: Yuqing Wei <[email protected]>
windoze and Yuqing-cat authored Jan 6, 2023
1 parent 01ea96d commit 308321d
Showing 6 changed files with 344 additions and 1 deletion.
@@ -25,6 +25,7 @@ import scala.collection.JavaConverters._
new JsonSubTypes.Type(value = classOf[PathList], name = "pathlist"),
new JsonSubTypes.Type(value = classOf[Jdbc], name = "jdbc"),
new JsonSubTypes.Type(value = classOf[GenericLocation], name = "generic"),
new JsonSubTypes.Type(value = classOf[SparkSqlLocation], name = "sparksql"),
new JsonSubTypes.Type(value = classOf[Snowflake], name = "snowflake"),
))
trait DataLocation {
@@ -0,0 +1,69 @@
package com.linkedin.feathr.offline.config.location

import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize
import com.linkedin.feathr.common.Header
import org.apache.spark.sql.{DataFrame, SparkSession}

@CaseClassDeserialize()
case class SparkSqlLocation(sql: Option[String] = None, table: Option[String] = None) extends DataLocation {
/**
* Backward Compatibility
* Much existing code expects a simple path
*
* @return the `path` or `url` of the data source
*
* WARN: This method is deprecated; use match/case on InputLocation
* and get `path` from `SimplePath` only
*/
override def getPath: String = sql.getOrElse(table.getOrElse(""))

/**
* Backward Compatibility
*
* @return the `path` or `url` of the data source, wrapped in a List
*
* WARN: This method is deprecated; use match/case on InputLocation
* and get `paths` from `PathList` only
*/
override def getPathList: List[String] = List(getPath)

/**
* Load DataFrame from Spark session
*
* @param ss SparkSession
* @return the DataFrame loaded from the SQL query or table
*/
override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String]): DataFrame = {
sql match {
case Some(sql) => {
ss.sql(sql)
}
case None => table match {
case Some(table) => {
ss.sqlContext.table(table)
}
case None => {
throw new IllegalArgumentException("`sql` and `table` parameters must not both be empty")
}
}
}
}

/**
* Write DataFrame to the location
*
* NOTE: SparkSqlLocation doesn't support writing; it cannot be used as a data sink
*
* @param ss SparkSession
* @param df DataFrame to write
*/
override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = ???

/**
* Tell if this location is file based
*
* @return boolean
*/
override def isFileBasedLocation(): Boolean = false
}
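
For context, the "sparksql" name registered in the JsonSubTypes list above is the type tag that routes a deserialized location config to this new SparkSqlLocation case class; the Python client added later in this commit (feathr_project/feathr/definition/source.py) emits the same tag from SparkSqlSource. A minimal sketch of the payload shape, with an illustrative query:

import json

# Illustrative payload mirroring SparkSqlSource.to_dict() further down in this diff;
# exactly one of "sql" or "table" is present, never both.
location = {
    "type": "sparksql",  # must match the JsonSubTypes name registered above
    "sql": "SELECT * FROM some_table",  # illustrative query
}
print(json.dumps(location))  # one-line JSON, as to_argument() produces for the job submitter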

@@ -1,6 +1,6 @@
package com.linkedin.feathr.offline.source.accessor

import com.linkedin.feathr.offline.config.location.{GenericLocation, Jdbc, PathList, SimplePath, Snowflake}
import com.linkedin.feathr.offline.config.location.{GenericLocation, Jdbc, PathList, SimplePath, Snowflake, SparkSqlLocation}
import com.linkedin.feathr.offline.source.DataSource
import com.linkedin.feathr.offline.source.dataloader.{CaseInsensitiveGenericRecordWrapper, DataLoaderFactory}
import com.linkedin.feathr.offline.testfwk.TestFwkUtils
@@ -35,6 +35,7 @@ private[offline] class NonTimeBasedDataSourceAccessor(
case PathList(paths) => paths.map(fileLoaderFactory.create(_).loadDataFrame()).reduce((x, y) => x.fuzzyUnion(y))
case Jdbc(_, _, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate())
case GenericLocation(_, _) => source.location.loadDf(SparkSession.builder().getOrCreate())
case SparkSqlLocation(_, _) => source.location.loadDf(SparkSession.builder().getOrCreate())
case Snowflake(_, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate())
case _ => fileLoaderFactory.createFromLocation(source.location).loadDataFrame()
}
@@ -0,0 +1,45 @@
package com.linkedin.feathr.offline.config.location

import com.linkedin.feathr.offline.TestFeathr
import com.linkedin.feathr.offline.source.accessor.DataSourceAccessor
import com.linkedin.feathr.offline.source.{DataSource, SourceFormatType}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.testng.Assert.assertEquals
import org.testng.annotations.{BeforeClass, Test}

import scala.collection.JavaConversions._

class TestSparkSqlLocation extends TestFeathr {
@BeforeClass
override def setup(): Unit = {
ss = TestFeathr.getOrCreateSparkSessionWithHive
super.setup()
}

@Test(description = "It should read a Spark SQL as the DataSource")
def testCreateWithFixedPath(): Unit = {
val schema = StructType(Array(
StructField("language", StringType, true),
StructField("users", IntegerType, true)
))
val rowData = Seq(Row("Java", 20000),
Row("Python", 100000),
Row("Scala", 3000))
val tempDf = ss.createDataFrame(rowData, schema)
tempDf.createTempView("test_spark_sql_table")

val loc = SparkSqlLocation(sql = Some("select * from test_spark_sql_table order by users asc"))
val dataSource = DataSource(loc, SourceFormatType.FIXED_PATH)
val accessor = DataSourceAccessor(ss, dataSource, None, None, failOnMissingPartition = false, dataPathHandlers=List())
val df = accessor.get()
val expectedRows = Array(
Row("Scala", 3000),
Row("Java", 20000),
Row("Python", 100000)
)
assertEquals(df.collect(), expectedRows)

ss.sqlContext.dropTempTable("test_spark_sql_table")
}
}
68 changes: 68 additions & 0 deletions feathr_project/feathr/definition/source.py
@@ -365,6 +365,74 @@ def to_feature_config(self) -> str:
def to_argument(self):
raise TypeError("KafkaSource cannot be used as observation source")

class SparkSqlSource(Source):
def __init__(self, name: str, sql: Optional[str] = None, table: Optional[str] = None, preprocessing: Optional[Callable] = None, event_timestamp_column: Optional[str] = None, timestamp_format: Optional[str] = "epoch", registry_tags: Optional[Dict[str, str]] = None) -> None:
""" SparkSqlSource can use either a sql query or a table name as the source for Feathr job.
name: name of the source
sql: SQL query to use as the source; either sql or table must be specified
table: table name to use as the source; either sql or table must be specified
preprocessing (Optional[Callable]): A preprocessing python function that transforms the source data for further feature transformation.
event_timestamp_column (Optional[str]): The timestamp field of your record. Sliding window aggregation features assume each record in the source data has a timestamp column.
timestamp_format (Optional[str], optional): The format of the timestamp field. Defaults to "epoch". Possible values are:
- `epoch` (seconds since epoch), for example `1647737463`
- `epoch_millis` (milliseconds since epoch), for example `1647737517761`
- Any date formats supported by [SimpleDateFormat](https://docs.oracle.com/javase/8/docs/api/java/text/SimpleDateFormat.html).
registry_tags: A dict of (str, str) that you can pass to the feature registry for better organization. For example, you can use {"deprecated": "true"} to indicate this source is deprecated.
"""
super().__init__(name, event_timestamp_column,
timestamp_format, registry_tags=registry_tags)
self.source_type = 'sparksql'
if sql is None and table is None:
raise ValueError("Either `sql` or `table` must be specified")
if sql is not None and table is not None:
raise ValueError("Only one of `sql` or `table` can be specified")
if sql is not None:
self.sql = sql
if table is not None:
self.table = table
self.preprocessing = preprocessing

def to_feature_config(self) -> str:
tm = Template("""
{{source.name}}: {
location: {
type: "sparksql"
{% if source.sql is defined %}
sql: "{{source.sql}}"
{% elif source.table is defined %}
table: "{{source.table}}"
{% endif %}
}
{% if source.event_timestamp_column is defined %}
timeWindowParameters: {
timestampColumn: "{{source.event_timestamp_column}}"
timestampColumnFormat: "{{source.timestamp_format}}"
}
{% endif %}
}
""")
msg = tm.render(source=self)
return msg

def get_required_properties(self):
return []

def to_dict(self) -> Dict[str, str]:
ret = self.options.copy()
ret["type"] = "sparksql"
# `sql` / `table` are only set when provided (see __init__), so guard the attribute lookup
if getattr(self, "sql", None):
ret["sql"] = self.sql
elif getattr(self, "table", None):
ret["table"] = self.table
return ret

def to_argument(self):
"""
One-line JSON string, used by job submitter
"""
return json.dumps(self.to_dict())


class GenericSource(Source):
"""
This class corresponds to 'GenericLocation' in Feathr core, but is only used as a Source.
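
A minimal usage sketch for the SparkSqlSource class added above (illustrative table and query names; assumes a working feathr installation). It shows both construction modes and, in the comments, roughly what to_feature_config() renders:

from feathr.definition.source import SparkSqlSource

# Query-backed source; exactly one of `sql` / `table` may be given.
query_source = SparkSqlSource(
    name="sparkSqlQuerySource",
    sql="SELECT * FROM some_spark_table",  # illustrative query
    event_timestamp_column="lpep_dropoff_datetime",
    timestamp_format="yyyy-MM-dd HH:mm:ss",
)

# Table-backed source.
table_source = SparkSqlSource(name="sparkSqlTableSource", table="some_spark_table")

# to_feature_config() renders roughly:
#   sparkSqlQuerySource: {
#       location: {
#           type: "sparksql"
#           sql: "SELECT * FROM some_spark_table"
#       }
#       timeWindowParameters: {
#           timestampColumn: "lpep_dropoff_datetime"
#           timestampColumnFormat: "yyyy-MM-dd HH:mm:ss"
#       }
#   }
print(query_source.to_feature_config())
# to_argument() wraps to_dict() into a one-line JSON string for the job submitter.
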
159 changes: 159 additions & 0 deletions feathr_project/test/test_spark_sql_source.py
@@ -0,0 +1,159 @@
import os
from datetime import datetime, timedelta
from pathlib import Path
from feathr import (BOOLEAN, FLOAT, INPUT_CONTEXT, INT32,
DerivedFeature, Feature, FeatureAnchor,
TypedKey, ValueType, WindowAggTransformation)

import pytest
from feathr import FeathrClient
from feathr import FeatureQuery
from feathr import ObservationSettings
from feathr import TypedKey
from feathr import ValueType
from feathr.definition.materialization_settings import BackfillTime, MaterializationSettings
from feathr.definition.sink import HdfsSink
from feathr.utils.job_utils import get_result_df
from test_utils.constants import Constants
from feathr.definition.source import SparkSqlSource


@pytest.mark.skipif(os.environ.get('SPARK_CONFIG__SPARK_CLUSTER') != "databricks",
reason="this test uses predefined Databricks table.")
def test_feathr_spark_sql_query_source():
test_workspace_dir = Path(
__file__).parent.resolve() / "test_user_workspace"
config_path = os.path.join(test_workspace_dir, "feathr_config.yaml")

_get_offline_features(config_path, _sql_query_source())
_get_offline_features(config_path, _sql_table_source())
_materialize_to_offline(config_path, _sql_query_source())


def _get_offline_features(config_path: str, sql_source: SparkSqlSource):
client: FeathrClient = _spark_sql_test_setup(config_path, sql_source)
location_id = TypedKey(key_column="DOLocationID",
key_column_type=ValueType.INT32,
description="location id in NYC",
full_name="nyc_taxi.location_id")

feature_query = FeatureQuery(
feature_list=["f_location_avg_fare"], key=location_id)
settings = ObservationSettings(
observation_path="wasbs://[email protected]/sample_data/green_tripdata_2020-04.csv",
event_timestamp_column="lpep_dropoff_datetime",
timestamp_format="yyyy-MM-dd HH:mm:ss")

now = datetime.now()
output_path = ''.join(['dbfs:/feathrazure_cijob_sparksql',
'_', str(now.minute), '_', str(now.second), ".avro"])

client.get_offline_features(observation_settings=settings,
feature_query=feature_query,
output_path=output_path)

# assuming the job runs successfully; otherwise it will throw an exception
client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS)
return


def _materialize_to_offline(config_path: str, sql_source: SparkSqlSource):
client: FeathrClient = _spark_sql_test_setup(config_path, sql_source)
backfill_time = BackfillTime(start=datetime(
2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1))

now = datetime.now()
if client.spark_runtime == 'databricks':
output_path = ''.join(['dbfs:/feathrazure_cijob_materialize_offline_sparksql',
'_', str(now.minute), '_', str(now.second), ""])
else:
output_path = ''.join(['abfss://[email protected]/demo_data/feathrazure_cijob_materialize_offline_sparksql',
'_', str(now.minute), '_', str(now.second), ""])
offline_sink = HdfsSink(output_path=output_path)
settings = MaterializationSettings("nycTaxiTable",
sinks=[offline_sink],
feature_names=[
"f_location_avg_fare", "f_location_max_fare"],
backfill_time=backfill_time)
client.materialize_features(settings)
# assuming the job runs successfully; otherwise it will throw an exception
client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS)

# download result and just assert the returned result is not empty
# by default, it will write to a folder appended with date
res_df = get_result_df(
client, "avro", output_path + "/df0/daily/2020/05/20")
assert res_df.shape[0] > 0


def _spark_sql_test_setup(config_path: str, sql_source: SparkSqlSource):
client = FeathrClient(config_path=config_path)

f_trip_distance = Feature(name="f_trip_distance",
feature_type=FLOAT, transform="trip_distance")
f_trip_time_duration = Feature(name="f_trip_time_duration",
feature_type=INT32,
transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60")

features = [
f_trip_distance,
f_trip_time_duration,
Feature(name="f_is_long_trip_distance",
feature_type=BOOLEAN,
transform="cast_float(trip_distance)>30"),
Feature(name="f_day_of_week",
feature_type=INT32,
transform="dayofweek(lpep_dropoff_datetime)"),
]

request_anchor = FeatureAnchor(name="request_features",
source=INPUT_CONTEXT,
features=features)

f_trip_time_distance = DerivedFeature(name="f_trip_time_distance",
feature_type=FLOAT,
input_features=[
f_trip_distance, f_trip_time_duration],
transform="f_trip_distance * f_trip_time_duration")

f_trip_time_rounded = DerivedFeature(name="f_trip_time_rounded",
feature_type=INT32,
input_features=[f_trip_time_duration],
transform="f_trip_time_duration % 10")

location_id = TypedKey(key_column="DOLocationID",
key_column_type=ValueType.INT32,
description="location id in NYC",
full_name="nyc_taxi.location_id")
agg_features = [Feature(name="f_location_avg_fare",
key=location_id,
feature_type=FLOAT,
transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)",
agg_func="AVG",
window="90d",
filter="fare_amount > 0"
)),
Feature(name="f_location_max_fare",
key=location_id,
feature_type=FLOAT,
transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)",
agg_func="MAX",
window="90d"))
]

agg_anchor = FeatureAnchor(name="aggregationFeatures",
source=sql_source,
features=agg_features)

client.build_features(anchor_list=[agg_anchor, request_anchor], derived_feature_list=[
f_trip_time_distance, f_trip_time_rounded])

return client


def _sql_query_source():
return SparkSqlSource(name="sparkSqlQuerySource", sql="SELECT * FROM green_tripdata_2020_04_with_index", event_timestamp_column="lpep_dropoff_datetime", timestamp_format="yyyy-MM-dd HH:mm:ss")


def _sql_table_source():
return SparkSqlSource(name="sparkSqlTableSource", table="green_tripdata_2020_04_with_index", event_timestamp_column="lpep_dropoff_datetime", timestamp_format="yyyy-MM-dd HH:mm:ss")
