
Address comments
jaymo001 committed Apr 19, 2022
1 parent c2ca82b commit 7ba5f56
Showing 17 changed files with 195 additions and 125 deletions.
45 changes: 44 additions & 1 deletion README.md
@@ -156,7 +156,7 @@ batch_source = HdfsSource(
                          timestamp_format="yyyy-MM-dd HH:mm:ss") # Supports various formats including epoch
```

### Beyond Features on Raw Data Sources - Derived Features
### Define features on top of other features - Derived Features

```python
# Compute a new feature (a.k.a. derived feature) on top of an existing feature
@@ -177,6 +177,49 @@ user_item_similarity = DerivedFeature(name="user_item_similarity",
transform="cosine_similarity(user_embedding, item_embedding)")
```
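
A derived feature can be as simple as rescaling a single upstream feature. The following is a minimal sketch, not part of this change: the feature names, the `key`/`input_features` arguments, and the miles-to-kilometres conversion are illustrative.

```python
# Hypothetical derived feature: convert an existing distance feature from miles to km.
f_trip_distance_km = DerivedFeature(name="f_trip_distance_km",
                                    feature_type=FLOAT,
                                    key=trip_key,                      # assumed key of the upstream feature
                                    input_features=[f_trip_distance],  # assumed anchored feature
                                    transform="f_trip_distance * 1.60934")
```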

## Define Streaming Features

```python
# Define input data schema
schema = AvroJsonSchema(schemaStr="""
{
"type": "record",
"name": "DriverTrips",
"fields": [
{"name": "driver_id", "type": "long"},
{"name": "trips_today", "type": "int"},
{
"name": "datetime",
"type": {"type": "long", "logicalType": "timestamp-micros"}
}
]
}
""")
stream_source = KafKaSource(name="kafkaStreamingSource",
kafkaConfig=KafkaConfig(brokers=["feathrazureci.servicebus.windows.net:9093"],
topics=["feathrcieventhub"],
schema=schema)
)

driver_id = TypedKey(key_column="driver_id",
key_column_type=ValueType.INT64,
description="driver id",
full_name="nyc driver id")

kafkaAnchor = FeatureAnchor(name="kafkaAnchor",
source=stream_source,
features=[Feature(name="f_modified_streaming_count",
feature_type=INT32,
transform="trips_today + 1",
key=driver_id),
Feature(name="f_modified_streaming_count2",
feature_type=INT32,
transform="trips_today + 2",
key=driver_id)]
)

```

## Running Feathr Examples

Follow the [quick start Jupyter Notebook](./feathr_project/feathrcli/data/feathr_user_workspace/nyc_driver_demo.ipynb) to try it out. There is also a companion [quick start guide](./docs/quickstart.md) containing a bit more explanation on the notebook.
2 changes: 1 addition & 1 deletion docs/README.md
@@ -150,7 +150,7 @@ batch_source = HdfsSource(
                          timestamp_format="yyyy-MM-dd HH:mm:ss") # Supports various formats including epoch
```

## Beyond Features on Raw Data Sources - Derived Features
## Define features on top of other features - Derived Features

```python
# Compute a new feature (a.k.a. derived feature) on top of an existing feature
15 changes: 15 additions & 0 deletions docs/concepts/feature-generation.md
@@ -93,6 +93,9 @@ kafkaAnchor = FeatureAnchor(name="kafkaAnchor",
)

```
Note that only the Feathr expression transformation is allowed in a streaming anchor at the moment.
Support for other transformation types is on the roadmap.
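
For illustration, a minimal sketch of an expression-only streaming feature — the feature name and expression are hypothetical; `INT32` and `driver_id` are as defined in the Kafka streaming example above:

```python
# Hypothetical streaming feature: only a Feathr expression transform is used,
# which is what streaming anchors currently support.
f_trips_doubled = Feature(name="f_trips_doubled",
                          feature_type=INT32,
                          transform="trips_today * 2",  # Feathr expression, evaluated per incoming record
                          key=driver_id)
```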

3. Start streaming job

```python
@@ -102,4 +105,16 @@ settings = MaterializationSettings(name="kafkaSampleDemo",
feature_names=['f_modified_streaming_count']
)
client.materialize_features(settings) # Will stream for 10 seconds since streamingTimeoutMs is 10000
```
4. Fetch streaming feature values

```python

res = client.get_online_features('kafkaSampleDemoFeature', '1',
['f_modified_streaming_count'])
# Get features for multiple feature keys
res = client.multi_get_online_features('kafkaSampleDemoFeature',
['1', '2'],
['f_modified_streaming_count'])

```
3 changes: 3 additions & 0 deletions feathr_project/feathr/_envvariableutil.py
@@ -42,6 +42,9 @@ def get_environment_variable_with_default(self, *args):
for arg in args:
yaml_layer = yaml_layer[arg]
return yaml_layer
except KeyError as exc:
logger.info(exc)
return ""
except yaml.YAMLError as exc:
logger.info(exc)

2 changes: 1 addition & 1 deletion feathr_project/feathr/client.py
@@ -119,7 +119,7 @@ def __init__(self, config_path:str = "./feathr_config.yaml", local_workspace_dir

# Kafka configs
self.kafka_endpoint = envutils.get_environment_variable_with_default(
'offline_store', 'kafka', 'kafka_endpoint')
'streaming', 'kafka', 'kafka_endpoint')

# spark configs
self.output_num_parts = envutils.get_environment_variable_with_default(
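
The Kafka endpoint is now read from the `streaming` section of the configuration rather than `offline_store`. A minimal sketch of how this lookup behaves with the updated helper, assuming `envutils` is the environment-variable utility instance used in `client.py` above (the missing-key name is hypothetical):

```python
# Reads streaming.kafka.kafka_endpoint from the yaml config (or the matching
# environment variable), e.g. 'sb://feathrazureci.servicebus.windows.net/'.
kafka_endpoint = envutils.get_environment_variable_with_default(
    'streaming', 'kafka', 'kafka_endpoint')

# With the new KeyError handling above, a key that is absent from the config is
# logged and an empty string is returned instead of raising, so optional
# sections such as `streaming` no longer break client construction.
missing = envutils.get_environment_variable_with_default(
    'streaming', 'kafka', 'no_such_key')
assert missing == ""
```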
2 changes: 1 addition & 1 deletion feathr_project/feathr/sink.py
@@ -31,7 +31,7 @@ def to_write_config(self) -> str:
{% if source.streaming %}
streaming: true
{% endif %}
{% if source.streamingTimeoutMs is defined %}
{% if source.streamingTimeoutMs %}
timeoutMs: {{source.streamingTimeoutMs}}
{% endif %}
}
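
The template condition was changed from `is defined` to a plain truthiness check. A minimal sketch, outside Feathr, of the difference this makes (the `StandInSink` class is a hypothetical stand-in for the sink object rendered by `to_write_config`):

```python
from jinja2 import Template

# With a truthiness check, a sink whose streamingTimeoutMs is None (or unset)
# no longer renders a dangling "timeoutMs: None" entry in the generated config.
tpl = Template(
    "{% if source.streamingTimeoutMs %}timeoutMs: {{ source.streamingTimeoutMs }}{% endif %}")

class StandInSink:
    def __init__(self, timeout_ms=None):
        self.streamingTimeoutMs = timeout_ms

print(repr(tpl.render(source=StandInSink())))       # '' -- nothing rendered
print(repr(tpl.render(source=StandInSink(10000))))  # 'timeoutMs: 10000'
```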
6 changes: 6 additions & 0 deletions feathr_project/feathr/source.py
@@ -111,6 +111,12 @@ def __str__(self):


class KafkaConfig:
"""Kafka config for a streaming source
Attributes:
brokers: broker/server addresses
topics: Kafka topics
schema: Kafka message schema
"""
def __init__(self, brokers: List[str], topics: List[str], schema: SourceSchema):
self.brokers = brokers
self.topics = topics
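
A brief construction sketch matching the documented attributes; the broker and topic values reuse the illustrative ones from the README section above, and the `schema` object (an `AvroJsonSchema`) is assumed to be defined as shown there:

```python
# Hypothetical: wire the documented KafkaConfig attributes into a streaming source.
kafka_config = KafkaConfig(brokers=["feathrazureci.servicebus.windows.net:9093"],  # broker/server addresses
                           topics=["feathrcieventhub"],                            # Kafka topics to subscribe to
                           schema=schema)                                          # message schema (AvroJsonSchema)
stream_source = KafKaSource(name="kafkaStreamingSource", kafkaConfig=kafka_config)
```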
@@ -64,6 +64,7 @@ offline_store:
jdbc_table: "feathrtesttable"

# reading from streaming source
streaming:
kafka:
kafka_endpoint: 'sb://feathrazureci.servicebus.windows.net/'

1 change: 1 addition & 0 deletions feathr_project/test/test_user_workspace/feathr_config.yaml
@@ -64,6 +64,7 @@ offline_store:
role: "ACCOUNTADMIN"

# reading from streaming source
streaming:
kafka:
kafka_endpoint: 'sb://feathrazureci.servicebus.windows.net/'

@@ -4,7 +4,23 @@ import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.codehaus.jackson.annotate.JsonProperty


/**
 * Kafka source config. Example:
 * {{{
 *   kafkaStreamingSource: {
 *     type: KAFKA
 *     config: {
 *       brokers: ["feathrazureci.servicebus.windows.net:9093"]
 *       topics: [feathrcieventhub]
 *       schema: {
 *         type = "avro"
 *         avroJson: "......"
 *       }
 *     }
 *   }
 * }}}
 */
@CaseClassDeserialize()
case class KafkaSchema(@JsonProperty("type") `type`: String,
@JsonProperty("avroJson") avroJson: String)
@@ -4,6 +4,7 @@ import com.databricks.spark.avro.SchemaConverters
import com.google.common.collect.Lists
import com.linkedin.feathr.common.JoiningFeatureParams
import com.linkedin.feathr.offline.config.location.KafkaEndpoint
import com.linkedin.feathr.offline.generation.outputProcessor.PushToRedisOutputProcessor.TABLE_PARAM_CONFIG_NAME
import com.linkedin.feathr.offline.generation.outputProcessor.RedisOutputUtils
import com.linkedin.feathr.offline.job.FeatureTransformation.getFeatureJoinKey
import com.linkedin.feathr.offline.job.{FeatureGenSpec, FeatureTransformation}
@@ -26,11 +27,21 @@ import java.nio.ByteBuffer
import scala.collection.JavaConverters._
import scala.collection.convert.wrapAll._

/**
* Class to ingest streaming features
*/
class StreamingFeatureGenerator {
@transient val anchorToDataFrameMapper = new AnchorToDataSourceMapper()

/**
* Ingest streaming features
* @param ss spark session
* @param featureGenSpec feature generation config specification
* @param featureGroups all features defined in the system
* @param keyTaggedFeatures streaming features to ingest/generate
*/
def generateFeatures(ss: SparkSession, featureGenSpec: FeatureGenSpec, featureGroups: FeatureGroups,
keyTaggedFeatures: Seq[JoiningFeatureParams]) = {
keyTaggedFeatures: Seq[JoiningFeatureParams]): Unit = {
val anchors = keyTaggedFeatures.map(streamingFeature => {
featureGroups.allAnchoredFeatures.get(streamingFeature.featureName).get
})
@@ -43,8 +54,8 @@
outputConfig.getParams.getNumber("timeoutMs").longValue()
}
// Load the raw streaming source data
val anchorDFRDDMap = anchorToDataFrameMapper.getAnchorDFMapForGen(ss, anchors, None, false, true)
anchorDFRDDMap.par.map { case (anchor, accessor) => {
val anchorDfRDDMap = anchorToDataFrameMapper.getAnchorDFMapForGen(ss, anchors, None, false, true)
anchorDfRDDMap.par.map { case (anchor, dfAccessor) => {
val schemaStr = anchor.source.location.asInstanceOf[KafkaEndpoint].schema.avroJson
val schemaStruct = SchemaConverters.toSqlType(Schema.parse(schemaStr)).dataType.asInstanceOf[StructType]
val rowForRecord = (input: Any) => {
Expand All @@ -55,14 +66,21 @@ class StreamingFeatureGenerator {
val reader = new GenericDatumReader[GenericRecord](avroSchema)
val record = reader.read(null, decoder)
for (field <- record.getSchema.getFields) {
var value = record.get(field.name)
var fieldType = field.schema.getType
if (fieldType.equals(Type.UNION)) fieldType = field.schema.getTypes.get(1).getType
// Avro returns Utf8s for strings, which Spark SQL doesn't know how to use
if (fieldType.equals(Type.STRING) && value != null) value = value.toString
// Avro returns binary as a ByteBuffer, but Spark SQL wants a byte[]
if (fieldType.equals(Type.BYTES) && value != null) value = value.asInstanceOf[ByteBuffer].array
values.add(value)
val fieldType = if (field.schema.getType.equals(Type.UNION)) {
field.schema.getTypes.get(1).getType
} else {
field.schema.getType
}
val fieldValue = Option(record.get(field.name)) match {
case Some(value) if fieldType.equals(Type.STRING) =>
// Avro returns Utf8s for strings, which Spark SQL doesn't know how to use
value.toString
case Some(value) if fieldType.equals(Type.BYTES) =>
// Avro returns binary as a ByteBuffer, but Spark SQL wants a byte[]
value.asInstanceOf[ByteBuffer].array
case _ => record.get(field.name)
}
values.add(fieldValue)
}

new CustomGenericRowWithSchema(
@@ -72,7 +90,7 @@ class StreamingFeatureGenerator {
val convertUDF = udf(rowForRecord)

// Streaming processing each source
accessor.get().writeStream
dfAccessor.get().writeStream
.outputMode(OutputMode.Update)
.foreachBatch{ (batchDF: DataFrame, batchId: Long) =>
// Convert each batch dataframe from the Kafka built-in schema (which always has a 'value' field) to the user-provided schema
@@ -97,8 +115,7 @@
val cleanedDF = transformedResult.df.select(selectedColumns.head, selectedColumns.tail:_*)
val keyColumnNames = FeatureTransformation.getStandardizedKeyNames(outputJoinKeyColumnNames.size)
val resultFDS: DataFrame = PostGenPruner().standardizeColumns(outputJoinKeyColumnNames, keyColumnNames, cleanedDF)
val tableParam = "table_name"
val tableName = outputConfig.getParams.getString(tableParam)
val tableName = outputConfig.getParams.getString(TABLE_PARAM_CONFIG_NAME)
val allFeatureCols = resultFDS.columns.diff(keyColumnNames).toSet
RedisOutputUtils.writeToRedis(ss, resultFDS, tableName, keyColumnNames, allFeatureCols, SaveMode.Append)
}
@@ -3,6 +3,7 @@ package com.linkedin.feathr.offline.generation.outputProcessor
import com.linkedin.feathr.common.Header
import com.linkedin.feathr.common.configObj.generation.OutputProcessorConfig
import com.linkedin.feathr.offline.generation.FeatureGenUtils
import com.linkedin.feathr.offline.generation.outputProcessor.PushToRedisOutputProcessor.TABLE_PARAM_CONFIG_NAME
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}

/**
@@ -22,10 +23,15 @@ private[offline] class PushToRedisOutputProcessor(config: OutputProcessorConfig,
*/
override def processSingle(ss: SparkSession, df: DataFrame, header: Header, parentPath: String): (DataFrame, Header) = {
val keyColumns = FeatureGenUtils.getKeyColumnsFromHeader(header)
val tableParam = "table_name"
val tableName = config.getParams.getString(tableParam)

val tableName = config.getParams.getString(TABLE_PARAM_CONFIG_NAME)
val allFeatureCols = header.featureInfoMap.map(x => (x._2.columnName)).toSet
RedisOutputUtils.writeToRedis(ss, df, tableName, keyColumns, allFeatureCols, SaveMode.Overwrite)
(df, header)
}
}

object PushToRedisOutputProcessor {
// Parameter name in Redis output processor config for table name
val TABLE_PARAM_CONFIG_NAME = "table_name"
}
@@ -25,7 +25,8 @@ private[offline] class StreamDataSourceAccessor(
*/
override def get(): DataFrame = {
if (!source.location.isInstanceOf[KafkaEndpoint]) {
throw new FeathrException(s"${source.location} is not a Kakfa streaming source location.")
throw new FeathrException(s"${source.location} is not a Kafka streaming source location." +
s" Only Kafka sources are supported right now.")
}
dataLoaderFactory.createFromLocation(source.location).loadDataFrame()
}
@@ -19,16 +19,21 @@ class KafkaDataLoader(ss: SparkSession, input: KafkaEndpoint) extends StreamData
val kafkaUsername = KafkaResourceInfoSetter.USERNAME
val sharedAccessKey = KafkaResourceInfoSetter.SHARED_ACCESS_KEY

// Construct authentication string for Kafka on Azure
private def getKafkaAuth(ss: SparkSession): String = {
// If the user has set a password, use it for authentication
ss.conf.getOption(kafkaEndpoint) match {
case Some(_) =>
val accessKeyName = ss.conf.get(sharedAccessKeyName)
val endpoint = ss.conf.get(kafkaEndpoint)
val username = ss.conf.get(kafkaUsername)
val accessKey = ss.conf.get(sharedAccessKey)
val EH_SASL = "org.apache.kafka.common.security.plain.PlainLoginModule required username=\""+username+"\"" +
" password=\"Endpoint="+endpoint+";SharedAccessKeyName="+accessKeyName+";SharedAccessKey="+accessKey+"\";"
val accessKeyName = ss.conf.getOption(sharedAccessKeyName)
val endpoint = ss.conf.getOption(kafkaEndpoint)
val username = ss.conf.getOption(kafkaUsername)
val accessKey = ss.conf.getOption(sharedAccessKey)
if (!accessKeyName.isDefined || !endpoint.isDefined || !username.isDefined || !accessKey.isDefined) {
throw new RuntimeException(s"Invalid Kafka authentication! ${kafkaEndpoint}, ${sharedAccessKeyName}," +
s" ${kafkaUsername} and ${sharedAccessKey} must be set in Spark conf.")
}
val EH_SASL = "org.apache.kafka.common.security.plain.PlainLoginModule required username=\""+username.get+"\"" +
" password=\"Endpoint="+endpoint.get+";SharedAccessKeyName="+accessKeyName.get+";SharedAccessKey="+accessKey.get+"\";"
EH_SASL
case _ => {
throw new RuntimeException(s"Invalid Kafka authentication! ${kafkaEndpoint} is not set in Spark conf.")
Expand All @@ -43,7 +48,8 @@ class KafkaDataLoader(ss: SparkSession, input: KafkaEndpoint) extends StreamData
.options(kafkaOptions)
.option("kafka.group.id", UUID.randomUUID().toString)
.option("kafka.sasl.jaas.config", EH_SASL)
.option("kafka.sasl.mechanism", "PLAIN")
// These are default settings; we can expose them in the future if required
.option("kafka.sasl.mechanism", "PLAIN")
.option("kafka.security.protocol", "SASL_SSL")
.option("kafka.request.timeout.ms", "60000")
.option("kafka.session.timeout.ms", "60000")
Expand All @@ -56,11 +62,12 @@ class KafkaDataLoader(ss: SparkSession, input: KafkaEndpoint) extends StreamData
* @return a DataFrame
*/
override def loadDataFrame(): DataFrame = {
val TOPIC = input.topics.mkString(",")
val BOOTSTRAP_SERVERS = input.brokers.mkString(",")
// Spark conf requires comma-delimited strings for multi-value configs
val topic = input.topics.mkString(",")
val bootstrapServers = input.brokers.mkString(",")
getDFReader(Map())
.option("subscribe", TOPIC)
.option("kafka.bootstrap.servers", BOOTSTRAP_SERVERS)
.option("subscribe", topic)
.option("kafka.bootstrap.servers", bootstrapServers)
.load()
}

@@ -27,16 +27,16 @@ class CustomGenericRowWithSchema(values: Array[Any], inputSchema: StructType)
class GenericRowWithSchemaUDT extends UserDefinedType[CustomGenericRowWithSchema] {
override def sqlType: DataType = org.apache.spark.sql.types.BinaryType
override def serialize(obj: CustomGenericRowWithSchema): Any = {
val bos = new ByteArrayOutputStream()
val oos = new ObjectOutputStream(bos)
oos.writeObject(obj)
bos.toByteArray
val byteStream = new ByteArrayOutputStream()
val objectStream = new ObjectOutputStream(byteStream)
objectStream.writeObject(obj)
byteStream.toByteArray
}
override def deserialize(datum: Any): CustomGenericRowWithSchema = {
val bis = new ByteArrayInputStream(datum.asInstanceOf[Array[Byte]])
val ois = new ObjectInputStream(bis)
val obj = ois.readObject()
obj.asInstanceOf[CustomGenericRowWithSchema]
val byteStream = new ByteArrayInputStream(datum.asInstanceOf[Array[Byte]])
val objectStream = new ObjectInputStream(byteStream)
val row = objectStream.readObject()
row.asInstanceOf[CustomGenericRowWithSchema]
}
override def userClass: Class[CustomGenericRowWithSchema] = classOf[CustomGenericRowWithSchema]
}
