Windoze/spark sql source #839

Merged 11 commits on Jan 6, 2023
@@ -25,6 +25,7 @@ import scala.collection.JavaConverters._
new JsonSubTypes.Type(value = classOf[PathList], name = "pathlist"),
new JsonSubTypes.Type(value = classOf[Jdbc], name = "jdbc"),
new JsonSubTypes.Type(value = classOf[GenericLocation], name = "generic"),
new JsonSubTypes.Type(value = classOf[SparkSqlLocation], name = "sparksql"),
new JsonSubTypes.Type(value = classOf[Snowflake], name = "snowflake"),
))
trait DataLocation {
@@ -0,0 +1,69 @@
package com.linkedin.feathr.offline.config.location

import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize
import com.linkedin.feathr.common.Header
import org.apache.spark.sql.{DataFrame, SparkSession}

@CaseClassDeserialize()
case class SparkSqlLocation(sql: Option[String] = None, table: Option[String] = None) extends DataLocation {
/**
* Backward compatibility:
* much existing code expects a simple path.
*
* @return the `sql` query if set, otherwise the `table` name (or an empty string if neither is set)
*
* WARN: This method is deprecated; match on the concrete DataLocation type
* and read `path` from `SimplePath` only.
*/
override def getPath: String = sql.getOrElse(table.getOrElse(""))

/**
* Backward compatibility
*
* @return the value of `getPath`, wrapped in a List
*
* WARN: This method is deprecated; match on the concrete DataLocation type
* and read `paths` from `PathList` only.
*/
override def getPathList: List[String] = List(getPath)

/**
* Load a DataFrame through the Spark session
*
* @param ss SparkSession
* @param dataIOParameters I/O parameters, not used by this location
* @return the DataFrame produced by running `sql`, or by reading `table`
*/
override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String]): DataFrame = {
sql match {
case Some(sql) => {
ss.sql(sql)
}
case None => table match {
case Some(table) => {
ss.sqlContext.table(table)
}
case None => {
throw new IllegalArgumentException("The `sql` and `table` parameters must not both be empty")
}
}
}
}

/**
* Write a DataFrame to the location
*
* NOTE: SparkSqlLocation doesn't support writing; it cannot be used as a data sink.
*
* @param ss SparkSession
* @param df DataFrame to write
*/
override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = ???

/**
* Tell whether this location is file based
*
* @return always false; Spark SQL locations are not file based
*/
override def isFileBasedLocation(): Boolean = false
}
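
As a quick orientation for the new class, here is a minimal usage sketch (the local session, view name, and query below are illustrative, not part of this PR): `loadDf` runs the query when `sql` is set, falls back to `table` otherwise, and rejects a location where neither is set.

```scala
import com.linkedin.feathr.offline.config.location.SparkSqlLocation
import org.apache.spark.sql.SparkSession

// Illustrative local session and temp view; any table or view visible to the session works the same way.
val ss = SparkSession.builder().master("local[*]").appName("sparksql-location-sketch").getOrCreate()
ss.range(3).toDF("id").createOrReplaceTempView("events")

// `sql` takes precedence; `table` is used only when `sql` is absent.
val bySql = SparkSqlLocation(sql = Some("SELECT id FROM events WHERE id > 0")).loadDf(ss, Map.empty[String, String])
val byTable = SparkSqlLocation(table = Some("events")).loadDf(ss, Map.empty[String, String])
// SparkSqlLocation() with neither field set throws IllegalArgumentException inside loadDf.
```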

@@ -1,6 +1,6 @@
package com.linkedin.feathr.offline.source.accessor

import com.linkedin.feathr.offline.config.location.{GenericLocation, Jdbc, PathList, SimplePath, Snowflake}
import com.linkedin.feathr.offline.config.location.{GenericLocation, Jdbc, PathList, SimplePath, Snowflake, SparkSqlLocation}
import com.linkedin.feathr.offline.source.DataSource
import com.linkedin.feathr.offline.source.dataloader.{CaseInsensitiveGenericRecordWrapper, DataLoaderFactory}
import com.linkedin.feathr.offline.testfwk.TestFwkUtils
@@ -35,6 +35,7 @@ private[offline] class NonTimeBasedDataSourceAccessor(
case PathList(paths) => paths.map(fileLoaderFactory.create(_).loadDataFrame()).reduce((x, y) => x.fuzzyUnion(y))
case Jdbc(_, _, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate())
case GenericLocation(_, _) => source.location.loadDf(SparkSession.builder().getOrCreate())
case SparkSqlLocation(_, _) => source.location.loadDf(SparkSession.builder().getOrCreate())
case Snowflake(_, _, _, _) => source.location.loadDf(SparkSession.builder().getOrCreate())
case _ => fileLoaderFactory.createFromLocation(source.location).loadDataFrame()
}
@@ -0,0 +1,45 @@
package com.linkedin.feathr.offline.config.location

import com.linkedin.feathr.offline.TestFeathr
import com.linkedin.feathr.offline.source.accessor.DataSourceAccessor
import com.linkedin.feathr.offline.source.{DataSource, SourceFormatType}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
import org.testng.Assert.assertEquals
import org.testng.annotations.{BeforeClass, Test}

import scala.collection.JavaConverters._

class TestSparkSqlLocation extends TestFeathr {
@BeforeClass
override def setup(): Unit = {
ss = TestFeathr.getOrCreateSparkSessionWithHive
super.setup()
}

@Test(description = "It should read a Spark SQL as the DataSource")
def testCreateWithFixedPath(): Unit = {
val schema = StructType(Array(
StructField("language", StringType, true),
StructField("users", IntegerType, true)
))
val rowData = Seq(Row("Java", 20000),
Row("Python", 100000),
Row("Scala", 3000))
val tempDf = ss.createDataFrame(rowData.asJava, schema)
tempDf.createTempView("test_spark_sql_table")

val loc = SparkSqlLocation(sql = Some("select * from test_spark_sql_table order by users asc"))
val dataSource = DataSource(loc, SourceFormatType.FIXED_PATH)
val accessor = DataSourceAccessor(ss, dataSource, None, None, failOnMissingPartition = false, dataPathHandlers=List())
val df = accessor.get()
val expectedRows = Array(
Row("Scala", 3000),
Row("Java", 20000),
Row("Python", 100000)
)
assertEquals(df.collect(), expectedRows)

ss.sqlContext.dropTempTable("test_spark_sql_table")
}
}
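
The test above exercises only the `sql` branch. A hedged sketch of additional cases that could sit alongside it (it assumes the same TestNG setup plus `import org.testng.Assert.assertThrows`; the view name is illustrative):

```scala
@Test(description = "It should read a table-backed SparkSqlLocation and reject an empty one")
def testTableAndEmptyLocation(): Unit = {
  // Self-contained setup: a small temp view used only by this sketch.
  ss.range(3).toDF("id").createOrReplaceTempView("sparksql_location_sketch")

  val byTable = SparkSqlLocation(table = Some("sparksql_location_sketch")).loadDf(ss, Map.empty[String, String])
  assertEquals(byTable.count(), 3L)

  // Neither `sql` nor `table` is set, so loadDf should fail fast.
  assertThrows(classOf[IllegalArgumentException], () => SparkSqlLocation().loadDf(ss, Map.empty[String, String]))

  ss.sqlContext.dropTempTable("sparksql_location_sketch")
}
```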
68 changes: 68 additions & 0 deletions feathr_project/feathr/definition/source.py
@@ -365,6 +365,74 @@ def to_feature_config(self) -> str:
def to_argument(self):
raise TypeError("KafKaSource cannot be used as observation source")

class SparkSqlSource(Source):
def __init__(self, name: str, sql: Optional[str] = None, table: Optional[str] = None, preprocessing: Optional[Callable] = None, event_timestamp_column: Optional[str] = None, timestamp_format: Optional[str] = "epoch", registry_tags: Optional[Dict[str, str]] = None) -> None:
""" SparkSqlSource can use either a sql query or a table name as the source for Feathr job.
name: name of the source
sql: sql query to use as the source, either sql or table must be specified
table: table name to use as the source, either sql or table must be specified
preprocessing (Optional[Callable]): A preprocessing python function that transforms the source data for further feature transformation.
event_timestamp_column (Optional[str]): The timestamp field of your record. As sliding window aggregation feature assume each record in the source data should have a timestamp column.
timestamp_format (Optional[str], optional): The format of the timestamp field. Defaults to "epoch". Possible values are:
- `epoch` (seconds since epoch), for example `1647737463`
- `epoch_millis` (milliseconds since epoch), for example `1647737517761`
- Any date formats supported by [SimpleDateFormat](https://docs.oracle.com/javase/8/docs/api/java/text/SimpleDateFormat.html).
registry_tags: A dict of (str, str) that you can pass to feature registry for better organization. For example, you can use {"deprecated": "true"} to indicate this source is deprecated, etc.
"""
super().__init__(name, event_timestamp_column,
timestamp_format, registry_tags=registry_tags)
self.source_type = 'sparksql'
if sql is None and table is None:
raise ValueError("Either `sql` or `table` must be specified")
if sql is not None and table is not None:
raise ValueError("Only one of `sql` or `table` can be specified")
if sql is not None:
self.sql = sql
if table is not None:
self.table = table
self.preprocessing = preprocessing

def to_feature_config(self) -> str:
tm = Template("""
{{source.name}}: {
location: {
type: "sparksql"
{% if source.sql is defined %}
sql: "{{source.sql}}"
{% elif source.table is defined %}
table: "{{source.table}}"
{% endif %}
}
{% if source.event_timestamp_column is defined %}
timeWindowParameters: {
timestampColumn: "{{source.event_timestamp_column}}"
timestampColumnFormat: "{{source.timestamp_format}}"
}
{% endif %}
}
""")
msg = tm.render(source=self)
return msg

def get_required_properties(self):
return []

def to_dict(self) -> Dict[str, str]:
ret = self.options.copy()
ret["type"] = "sparksql"
# `sql` and `table` are only set when provided, so guard with getattr to avoid AttributeError
if getattr(self, "sql", None):
ret["sql"] = self.sql
elif getattr(self, "table", None):
ret["table"] = self.table
return ret

def to_argument(self):
"""
One-line JSON string, used by job submitter
"""
return json.dumps(self.to_dict())


class GenericSource(Source):
"""
This class corresponds to 'GenericLocation' in Feathr core, but can only be used as a Source.
159 changes: 159 additions & 0 deletions feathr_project/test/test_spark_sql_source.py
@@ -0,0 +1,159 @@
import os
from datetime import datetime, timedelta
from pathlib import Path
from feathr import (BOOLEAN, FLOAT, INPUT_CONTEXT, INT32,
DerivedFeature, Feature, FeatureAnchor,
TypedKey, ValueType, WindowAggTransformation)

import pytest
from feathr import FeathrClient
from feathr import FeatureQuery
from feathr import ObservationSettings
from feathr import TypedKey
from feathr import ValueType
from feathr.definition.materialization_settings import BackfillTime, MaterializationSettings
from feathr.definition.sink import HdfsSink
from feathr.utils.job_utils import get_result_df
from test_utils.constants import Constants
from feathr.definition.source import SparkSqlSource


@pytest.mark.skipif(os.environ.get('SPARK_CONFIG__SPARK_CLUSTER') != "databricks",
reason="this test uses a predefined Databricks table.")
def test_feathr_spark_sql_query_source():
test_workspace_dir = Path(
__file__).parent.resolve() / "test_user_workspace"
config_path = os.path.join(test_workspace_dir, "feathr_config.yaml")

_get_offline_features(config_path, _sql_query_source())
_get_offline_features(config_path, _sql_table_source())
_materialize_to_offline(config_path, _sql_query_source())


def _get_offline_features(config_path: str, sql_source: SparkSqlSource):
client: FeathrClient = _spark_sql_test_setup(config_path, sql_source)
location_id = TypedKey(key_column="DOLocationID",
key_column_type=ValueType.INT32,
description="location id in NYC",
full_name="nyc_taxi.location_id")

feature_query = FeatureQuery(
feature_list=["f_location_avg_fare"], key=location_id)
settings = ObservationSettings(
observation_path="wasbs://[email protected]/sample_data/green_tripdata_2020-04.csv",
event_timestamp_column="lpep_dropoff_datetime",
timestamp_format="yyyy-MM-dd HH:mm:ss")

now = datetime.now()
output_path = ''.join(['dbfs:/feathrazure_cijob_sparksql',
'_', str(now.minute), '_', str(now.second), ".avro"])

client.get_offline_features(observation_settings=settings,
feature_query=feature_query,
output_path=output_path)

# assuming the job runs successfully; otherwise it will throw an exception
client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS)
return


def _materialize_to_offline(config_path: str, sql_source: SparkSqlSource):
client: FeathrClient = _spark_sql_test_setup(config_path, sql_source)
backfill_time = BackfillTime(start=datetime(
2020, 5, 20), end=datetime(2020, 5, 20), step=timedelta(days=1))

now = datetime.now()
if client.spark_runtime == 'databricks':
output_path = ''.join(['dbfs:/feathrazure_cijob_materialize_offline_sparksql',
'_', str(now.minute), '_', str(now.second), ""])
else:
output_path = ''.join(['abfss://[email protected]/demo_data/feathrazure_cijob_materialize_offline_sparksql',
'_', str(now.minute), '_', str(now.second), ""])
offline_sink = HdfsSink(output_path=output_path)
settings = MaterializationSettings("nycTaxiTable",
sinks=[offline_sink],
feature_names=[
"f_location_avg_fare", "f_location_max_fare"],
backfill_time=backfill_time)
client.materialize_features(settings)
# assuming the job runs successfully; otherwise it will throw an exception
client.wait_job_to_finish(timeout_sec=Constants.SPARK_JOB_TIMEOUT_SECONDS)

# download result and just assert the returned result is not empty
# by default, it will write to a folder appended with date
res_df = get_result_df(
client, "avro", output_path + "/df0/daily/2020/05/20")
assert res_df.shape[0] > 0


def _spark_sql_test_setup(config_path: str, sql_source: SparkSqlSource):
client = FeathrClient(config_path=config_path)

f_trip_distance = Feature(name="f_trip_distance",
feature_type=FLOAT, transform="trip_distance")
f_trip_time_duration = Feature(name="f_trip_time_duration",
feature_type=INT32,
transform="(to_unix_timestamp(lpep_dropoff_datetime) - to_unix_timestamp(lpep_pickup_datetime))/60")

features = [
f_trip_distance,
f_trip_time_duration,
Feature(name="f_is_long_trip_distance",
feature_type=BOOLEAN,
transform="cast_float(trip_distance)>30"),
Feature(name="f_day_of_week",
feature_type=INT32,
transform="dayofweek(lpep_dropoff_datetime)"),
]

request_anchor = FeatureAnchor(name="request_features",
source=INPUT_CONTEXT,
features=features)

f_trip_time_distance = DerivedFeature(name="f_trip_time_distance",
feature_type=FLOAT,
input_features=[
f_trip_distance, f_trip_time_duration],
transform="f_trip_distance * f_trip_time_duration")

f_trip_time_rounded = DerivedFeature(name="f_trip_time_rounded",
feature_type=INT32,
input_features=[f_trip_time_duration],
transform="f_trip_time_duration % 10")

location_id = TypedKey(key_column="DOLocationID",
key_column_type=ValueType.INT32,
description="location id in NYC",
full_name="nyc_taxi.location_id")
agg_features = [Feature(name="f_location_avg_fare",
key=location_id,
feature_type=FLOAT,
transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)",
agg_func="AVG",
window="90d",
filter="fare_amount > 0"
)),
Feature(name="f_location_max_fare",
key=location_id,
feature_type=FLOAT,
transform=WindowAggTransformation(agg_expr="cast_float(fare_amount)",
agg_func="MAX",
window="90d"))
]

agg_anchor = FeatureAnchor(name="aggregationFeatures",
source=sql_source,
features=agg_features)

client.build_features(anchor_list=[agg_anchor, request_anchor], derived_feature_list=[
f_trip_time_distance, f_trip_time_rounded])

return client


def _sql_query_source():
return SparkSqlSource(name="sparkSqlQuerySource", sql="SELECT * FROM green_tripdata_2020_04_with_index", event_timestamp_column="lpep_dropoff_datetime", timestamp_format="yyyy-MM-dd HH:mm:ss")


def _sql_table_source():
return SparkSqlSource(name="sparkSqlTableSource", table="green_tripdata_2020_04_with_index", event_timestamp_column="lpep_dropoff_datetime", timestamp_format="yyyy-MM-dd HH:mm:ss")