diff --git a/feathr-impl/src/main/java/com/linkedin/feathr/common/FeatureVariableResolver.java b/feathr-impl/src/main/java/com/linkedin/feathr/common/FeatureVariableResolver.java index 3dd6718f9..bb84152ca 100644 --- a/feathr-impl/src/main/java/com/linkedin/feathr/common/FeatureVariableResolver.java +++ b/feathr-impl/src/main/java/com/linkedin/feathr/common/FeatureVariableResolver.java @@ -4,8 +4,12 @@ import com.linkedin.feathr.common.tensor.TensorIterator; import com.linkedin.feathr.common.types.ValueType; import com.linkedin.feathr.common.util.CoercionUtils; +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext; +import org.mvel2.DataConversion; import org.mvel2.integration.impl.SimpleValueResolver; +import java.util.Optional; + /** * FeatureVariableResolver takes a FeatureValue object for member variable during MVEL expression evaluation, @@ -13,10 +17,11 @@ */ public class FeatureVariableResolver extends SimpleValueResolver { private FeatureValue _featureValue; - - public FeatureVariableResolver(FeatureValue featureValue) { + private Optional _mvelContext = Optional.empty(); + public FeatureVariableResolver(FeatureValue featureValue, Optional mvelContext) { super(featureValue); _featureValue = featureValue; + _mvelContext = mvelContext; } @Override @@ -25,21 +30,27 @@ public Object getValue() { return null; } + Object fv = null; switch (_featureValue.getFeatureType().getBasicType()) { case NUMERIC: - return _featureValue.getAsNumeric(); + fv = _featureValue.getAsNumeric(); break; case TERM_VECTOR: - return getValueFromTermVector(); + fv = getValueFromTermVector(); break; case BOOLEAN: case CATEGORICAL: case CATEGORICAL_SET: case DENSE_VECTOR: case TENSOR: - return getValueFromTensor(); - + fv = getValueFromTensor(); break; default: - throw new IllegalArgumentException("Unexpected feature type: " + _featureValue.getFeatureType().getBasicType()); + throw new IllegalArgumentException("Unexpected feature type: " + _featureValue.getFeatureType().getBasicType()); + } + // If there is any registered FeatureValue handler that can handle this feature value, return the converted value per request. + if (_mvelContext.isPresent() && _mvelContext.get().canConvertFromAny(fv)) { + return _mvelContext.get().convertFromAny(fv).head(); + } else { + return fv; } } diff --git a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/PostTransformationUtil.scala b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/PostTransformationUtil.scala index b1f75d662..79518fd10 100644 --- a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/PostTransformationUtil.scala +++ b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/PostTransformationUtil.scala @@ -130,7 +130,7 @@ private[offline] object PostTransformationUtil { featureType: FeatureTypes, mvelContext: Option[FeathrExpressionExecutionContext]): Try[FeatureValue] = Try { val args = Map(featureName -> Some(featureValue)) - val variableResolverFactory = new FeatureVariableResolverFactory(args) + val variableResolverFactory = new FeatureVariableResolverFactory(args, mvelContext) val transformedValue = MvelContext.executeExpressionWithPluginSupportWithFactory(compiledExpression, featureValue, variableResolverFactory, mvelContext.orNull) CoercionUtilsScala.coerceToFeatureValue(transformedValue, featureType) } diff --git a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/derived/functions/MvelFeatureDerivationFunction.scala b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/derived/functions/MvelFeatureDerivationFunction.scala index 42f09ad21..e8f7d4196 100644 --- a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/derived/functions/MvelFeatureDerivationFunction.scala +++ b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/derived/functions/MvelFeatureDerivationFunction.scala @@ -42,7 +42,7 @@ private[offline] class MvelFeatureDerivationFunction( override def getFeatures(inputs: Seq[Option[common.FeatureValue]]): Seq[Option[common.FeatureValue]] = { val argMap = (parameterNames zip inputs).toMap - val variableResolverFactory = new FeatureVariableResolverFactory(argMap) + val variableResolverFactory = new FeatureVariableResolverFactory(argMap, mvelContext) MvelUtils.executeExpression(compiledExpression, null, variableResolverFactory, featureName, mvelContext) match { case Some(value) => diff --git a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/derived/functions/MvelFeatureDerivationFunction1.scala b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/derived/functions/MvelFeatureDerivationFunction1.scala index 2d5e30fb8..7e403089e 100644 --- a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/derived/functions/MvelFeatureDerivationFunction1.scala +++ b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/derived/functions/MvelFeatureDerivationFunction1.scala @@ -39,7 +39,7 @@ private[offline] class MvelFeatureDerivationFunction1( override def getFeatures(inputs: Seq[Option[common.FeatureValue]]): Seq[Option[common.FeatureValue]] = { val argMap = (parameterNames zip inputs).toMap - val variableResolverFactory = new FeatureVariableResolverFactory(argMap) + val variableResolverFactory = new FeatureVariableResolverFactory(argMap, mvelContext) MvelUtils.executeExpression(compiledExpression, null, variableResolverFactory, featureName, mvelContext) match { case Some(value) => diff --git a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/derived/functions/SimpleMvelDerivationFunction.scala b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/derived/functions/SimpleMvelDerivationFunction.scala index 7ed9ad0a8..8d9695e2b 100644 --- a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/derived/functions/SimpleMvelDerivationFunction.scala +++ b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/derived/functions/SimpleMvelDerivationFunction.scala @@ -44,7 +44,7 @@ private[offline] class SimpleMvelDerivationFunction(expression: String, featureN MvelContext.ensureInitialized() // In order to prevent MVEL from barfing if a feature is null, we use a custom variable resolver that understands `Option` - val variableResolverFactory = new FeatureVariableResolverFactory(args) + val variableResolverFactory = new FeatureVariableResolverFactory(args, mvelContext) if (TestFwkUtils.IS_DEBUGGER_ENABLED) { while(TestFwkUtils.DERIVED_FEATURE_COUNTER > 0) { diff --git a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/mvel/FeatureVariableResolverFactory.scala b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/mvel/FeatureVariableResolverFactory.scala index a81e04d11..8f05ac3f0 100644 --- a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/mvel/FeatureVariableResolverFactory.scala +++ b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/mvel/FeatureVariableResolverFactory.scala @@ -1,13 +1,16 @@ package com.linkedin.feathr.offline.mvel import com.linkedin.feathr.common.{FeatureValue, FeatureVariableResolver} +import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import org.mvel2.integration.VariableResolver import org.mvel2.integration.impl.BaseVariableResolverFactory +import java.util.Optional import scala.collection.JavaConverters._ -private[offline] class FeatureVariableResolverFactory(features: Map[String, Option[FeatureValue]]) extends BaseVariableResolverFactory { - variableResolvers = features.mapValues(x => new FeatureVariableResolver(x.orNull)).asInstanceOf[Map[String, VariableResolver]].asJava +private[offline] class FeatureVariableResolverFactory(features: Map[String, Option[FeatureValue]], mvelContext: Option[FeathrExpressionExecutionContext]) extends BaseVariableResolverFactory { + + variableResolvers = features.mapValues(x => new FeatureVariableResolver(x.orNull, Optional.ofNullable(mvelContext.orNull))).asInstanceOf[Map[String, VariableResolver]].asJava override def isTarget(name: String): Boolean = features.contains(name) diff --git a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/mvel/plugins/FeathrExpressionExecutionContext.scala b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/mvel/plugins/FeathrExpressionExecutionContext.scala index ba63a654c..91d341d5f 100644 --- a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/mvel/plugins/FeathrExpressionExecutionContext.scala +++ b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/mvel/plugins/FeathrExpressionExecutionContext.scala @@ -70,6 +70,7 @@ class FeathrExpressionExecutionContext extends Serializable { */ def canConvert(toType: Class[_], convertFrom: Class[_]): Boolean = { if (isAssignableFrom(toType, convertFrom)) return true + if (isAssignableFrom(classOf[FeatureValueWrapper[toType.type]], convertFrom)) return true if (converters.value.contains(toType.getCanonicalName)) { converters.value.get(toType.getCanonicalName).get.canConvertFrom(toNonPrimitiveType(convertFrom)) } else if (toType.isArray && canConvert(toType.getComponentType, convertFrom)) { @@ -79,6 +80,28 @@ class FeathrExpressionExecutionContext extends Serializable { } } + /** + * Check if there is registered converters that can handle the conversion. + * @param inputValue input value to check + * @return whether it can be converted or not + */ + def canConvertFromAny(inputValue: AnyRef): Boolean = { + val result = converters.value.filter(converter => converter._2.canConvertFrom(inputValue.getClass)) + result.nonEmpty + } + + /** + * Convert the input Check if there is registered converters that can handle the conversion. + * @param inputValue input value to convert + * @return return all converted values produced by registered converters + */ + def convertFromAny(inputValue: AnyRef): List[AnyRef] = { + converters.value.collect { + case converter if converter._2.canConvertFrom(inputValue.getClass) => + converter._2.convertFrom(inputValue) + }.toList + } + /** * Convert the input to output type using the registered converters * @param in value to be converted @@ -88,6 +111,9 @@ class FeathrExpressionExecutionContext extends Serializable { */ def convert[T](in: Any, toType: Class[T]): T = { if ((toType eq in.getClass) || toType.isAssignableFrom(in.getClass)) return in.asInstanceOf[T] + if (isAssignableFrom(classOf[FeatureValueWrapper[toType.type]], in.getClass)) { + return in.asInstanceOf[FeatureValueWrapper[_]].getFeatureValue().asInstanceOf[T] + } val converter = if (converters.value != null) { converters.value.get(toType.getCanonicalName).get } else { diff --git a/feathr-impl/src/main/scala/com/linkedin/feathr/offline/mvel/plugins/FeatureValueWrapper.scala b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/mvel/plugins/FeatureValueWrapper.scala new file mode 100644 index 000000000..363763ca9 --- /dev/null +++ b/feathr-impl/src/main/scala/com/linkedin/feathr/offline/mvel/plugins/FeatureValueWrapper.scala @@ -0,0 +1,10 @@ +package com.linkedin.feathr.offline.mvel.plugins + +/** + * Trait that wraps a Frame or Feathr FeatureValue + * @tparam T FeatureValue type to be wrapped + */ +trait FeatureValueWrapper[T] { + // Get the wrapped feature value + def getFeatureValue(): T +} \ No newline at end of file diff --git a/feathr-impl/src/test/java/com/linkedin/feathr/offline/plugins/AlienFeatureValue.java b/feathr-impl/src/test/java/com/linkedin/feathr/offline/plugins/AlienFeatureValue.java index d640020c2..50a831259 100644 --- a/feathr-impl/src/test/java/com/linkedin/feathr/offline/plugins/AlienFeatureValue.java +++ b/feathr-impl/src/test/java/com/linkedin/feathr/offline/plugins/AlienFeatureValue.java @@ -10,7 +10,10 @@ private AlienFeatureValue(Float floatValue, String stringValue) { this.floatValue = floatValue; this.stringValue = stringValue; } - + public AlienFeatureValue() { + this.floatValue = null; + this.stringValue = null; + } public static AlienFeatureValue fromFloat(float floatValue) { return new AlienFeatureValue(floatValue, null); } diff --git a/feathr-impl/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala b/feathr-impl/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala index c93774683..071db5034 100644 --- a/feathr-impl/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala +++ b/feathr-impl/src/test/scala/com/linkedin/feathr/offline/AnchoredFeaturesIntegTest.scala @@ -553,298 +553,6 @@ class AnchoredFeaturesIntegTest extends FeathrIntegTest { setFeathrJobParam(ADD_DEFAULT_COL_FOR_MISSING_DATA, "false") } - /* - * Test features with fdsExtract. - */ - @Test - def testFeaturesWithFdsExtract: Unit = { - val df = runLocalFeatureJoinForTest( - joinConfigAsString = - """ - | features: { - | key: a_id - | featureList: ["platoScore", "maxLoginScore". "profileHasPicture", "ipDataCountryCode"] - | } - """.stripMargin, - featureDefAsString = - """ - | anchors: { - | anchor1: { - | source: "anchorAndDerivations/nullValueSource.avro.json" - | key.sqlExpr: mId - | features: { - | featureWithNull { - | def.sqlExpr: "FDSExtract(coalesce(denseValue, ARRAY(ARRAY(\"aa\", \"bb\", \"cc\", \"dd\", \"ee\"), ARRAY(\"UNK\", \"UNK\", \"UNK\", \"UNK\", \"UNK\")) ))" - | type:{ - | type: TENSOR - | tensorCategory: DENSE - | shape: [2,5] - | dimensionType: [INT, INT] - | valType: STRING - | } - | } - | } - | } - | - | platoFlatFeatureVector: { - | source:"/tmp/fraud/jobs-fraud-model/atoscoresandfeaturesAvroPCV2" - | key.sqlExpr: ["memberId", "substring(date, 0, 10)"] - | features: { - | platoScore: { - | def.sqlExpr: "score", - | default: 0.0, - | type: "NUMERIC" - | } - | } - | } - | fraudJobsRatioPerIPFeatures: { - | source: "/jobs/fraud/zirannia/tf_home_dir/jobs-fraud-model/fraudJobsRatioPerIPFeature" - | key.sqlExpr: ["jobPostingIP", "date"] - | features: { - | fraudJobsIPFeaturesFedexNumJobs: { - | def.sqlExpr: "numJobs", - | default: -1.0, - | type: "NUMERIC" - | }, - | fraudJobsIPFeaturesFedexNumFraudJobs: { - | def.sqlExpr: "numFraudJobs", - | default: -1.0, - | type: "NUMERIC" - | }, - | fraudJobsIPFeaturesFedexFraudJobsRatio: { - | def.sqlExpr: "fraudJobsRatio" - | default: -1.0, - | type: "NUMERIC" - | } - | } - | } - | - | paramsFeatures: { - | source: "/jobs/fraud/zirannia/tf_home_dir/jobs-fraud-model/paramsFeatures" - | key.sqlExpr: ["jobId", "memberId", "substring(date, 0, 10)"] - | features: { - | maxLoginScore: { - | def.sqlExpr: "maxLoginScore", - | default: -1.0, - | type: "NUMERIC" - | }, - | registrationScore: { - | def.sqlExpr: "registrationScore", - | default: -1.0, - | type: "NUMERIC" - | }, - | profileHasPicture: { - | def.sqlExpr: "profileHasPicture" - | default: true, - | type: "BOOLEAN" - | }, - | accountageindays: { - | def.sqlExpr: "accountAgeInDays", - | default: 0.0, - | type: "NUMERIC" - | }, - | ipabusescore: { - | def.sqlExpr: "ipAbuseScore", - | default: 0.0, - | type: "NUMERIC" - | }, - | countryabusescore: { - | def.sqlExpr: "countryAbuseScore", - | default: 0.0, - | type: "NUMERIC" - | }, - | passwordResetFailureCountByIP: { - | def.sqlExpr: "passwordResetFailureCountByIP", - | default: 0.0, - | type: "NUMERIC" - | }, - | passwordResetSuccessCountByIP: { - | def.sqlExpr: "passwordResetSuccessCountByIP", - | default: 0.0, - | type: "NUMERIC" - | }, - | isEmailDomainReputationAbusive: { - | def.sqlExpr: "isEmailDomainReputationAbusive", - | default: false, - | type: "BOOLEAN" - | } - | } - | } - | - | restOfParamsFeatures: { - | source: "/jobs/fraud/zirannia/tf_home_dir/jobs-fraud-model/restOfParamFeatures" - | key.sqlExpr: ["job_id", "memberId", "substring(date, 0, 10)"] - | features: { - | asnIsAbusive: { - | def.sqlExpr: "asnIsAbusive", - | default: false, - | type: "BOOLEAN" - | }, - | jobAndMemberCountryMatch: { - | def.sqlExpr: "jobAndMemberCountryMatch", - | default: true, - | type: "BOOLEAN" - | }, - | companyCreationTime: { - | def.sqlExpr: "companyCreationTime", - | default: -1.0, - | type: "NUMERIC" - | }, - | asnIsOwnedByHostingService: { - | def.sqlExpr: "asnIsOwnedByHostingService", - | default: false, - | type: "BOOLEAN" - | }, - | companyFollowerCount: { - | def.sqlExpr: "companyFollowerCount" - | default: 0.0, - | type: "NUMERIC" - | }, - | countryMismatchGoodMemberCountWithSpecifiedAge: { - | def.sqlExpr: "countryMismatchGoodMemberCountWithSpecifiedAge", - | default: 0.0, - | type: "NUMERIC" - | }, - | countryMismatchRestrictedMemberCountWithSpecifiedAge: { - | def.sqlExpr: "countryMismatchRestrictedMemberCountWithSpecifiedAge", - | default: 0.0, - | type: "NUMERIC" - | }, - | dfpScore: { - | def.sqlExpr: "dfpScore", - | default: 0.0, - | type: "NUMERIC" - | }, - | ipIsOwnedByHostingService: { - | def.sqlExpr: "ipIsOwnedByHostingService", - | default: false, - | type: "BOOLEAN" - | }, - | isEmailDomainReputationCorpOwned: { - | def.sqlExpr: "isEmailDomainReputationCorpOwned", - | default: false, - | type: "BOOLEAN" - | }, - | numDaysActive: { - | def.sqlExpr: "numDaysActive", - | default: 0.0, - | type: "NUMERIC" - | }, - | orgIsAbusive: { - | def.sqlExpr: "orgIsAbusive", - | default: false, - | type: "BOOLEAN" - | }, - | orgIsOwnedByHostingService: { - | def.sqlExpr: "orgIsOwnedByHostingService", - | default: false, - | type: "BOOLEAN" - | }, - | postingMemberConnectionCount: { - | def.sqlExpr: "postingMemberConnectionCount", - | default: 0.0, - | type: "NUMERIC" - | }, - | useragentabusescore: { - | def.sqlExpr: "useragentabusescore", - | default: 0.0, - | type: "NUMERIC" - | }, - | jobPostingEmailDomain: { - | def.sqlExpr: "jobPostingEmailDomain", - | default: "", - | type: "CATEGORICAL" - | }, - | ipDataCountryCode: { - | def.sqlExpr: "ipDataCountryCode", - | default: "", - | type: "CATEGORICAL" - | }, - | jobPostCountryCode: { - | def.sqlExpr: "jobPostCountryCode", - | default: "", - | type: "CATEGORICAL" - | }, - | authorCountryCode: { - | def.sqlExpr: "authorCountryCode", - | default: "", - | type: "CATEGORICAL" - | }, - | wasWarmRegistration: { - | def.sqlExpr: "wasWarmRegistration", - | default: true, - | type: "BOOLEAN" - | }, - | emailDomainReputationRestrictionRatio: { - | def.sqlExpr: "emailDomainReputationRestrictionRatio", - | default: 0.0, - | type: "NUMERIC" - | } - | } - | } - | - | joinIpFeatures: { - | source: "dalids:///u_secaggs.joinipfeatures_datepartitioned" - | key.sqlExpr: ["memberid", "datepartition"] - | features: { - | joinIpFeaturesFedexFractionRestrictedReg: { - | def.sqlExpr: "fractionrestrictedreg" - | default: -1.0, - | type: "NUMERIC" - | } - | } - | } - | - | SafireFeatures: { - | source: "/tmp/fraud/jobs-fraud-model/safirescoresandfeaturesAvroPCV2" - | key.sqlExpr: ["memberId", "substring(lastActivityDate, 0, 10)"] - | features: { - | safireScore: { - | def.sqlExpr: "versionInfo.finalScore", - | default: 0.0, - | type: "NUMERIC" - | } - | } - | } - |} - """.stripMargin, - observationDataPath = "anchorAndDerivations/testMVELLoopExpFeature-observations.csv") - - val selectedColumns = Seq("a_id", "featureWithNull") - val filteredDf = df.data.select(selectedColumns.head, selectedColumns.tail: _*) - - val expectedDf = ss.createDataFrame( - ss.sparkContext.parallelize( - Seq( - Row( - // a_id - "1", - // featureWithNull - mutable.WrappedArray.make(Array(Array("aa", "bb", "cc", "dd", "ee"), Array("a", "a", "a", "a", "a"))), - ), - Row( - // a_id - "2", - // f3eatureWithNull - mutable.WrappedArray.make(Array(Array("aa", "bb", "cc", "dd", "ee"), Array("UNK", "UNK", "UNK", "UNK", "UNK"))) - ), - Row( - // a_id - "3", - // featureWithNull - mutable.WrappedArray.make(Array(Array("aa", "bb", "cc", "dd", "ee"), Array("a", "a", "a", "a", "a")), - )))), - StructType( - List( - StructField("a_id", StringType, true), - StructField("featureWithNull", ArrayType(ArrayType(StringType, true), true), true) - ))) - - def cmpFunc(row: Row): String = row.get(0).toString - - FeathrTestUtils.assertDataFrameApproximatelyEquals(filteredDf, expectedDf, cmpFunc) - } - /* * Test features with null values. diff --git a/feathr-impl/src/test/scala/com/linkedin/feathr/offline/TestFeathrUdfPlugins.scala b/feathr-impl/src/test/scala/com/linkedin/feathr/offline/TestFeathrUdfPlugins.scala index 64d2cee62..d2d2516fb 100644 --- a/feathr-impl/src/test/scala/com/linkedin/feathr/offline/TestFeathrUdfPlugins.scala +++ b/feathr-impl/src/test/scala/com/linkedin/feathr/offline/TestFeathrUdfPlugins.scala @@ -1,29 +1,41 @@ package com.linkedin.feathr.offline -import com.linkedin.feathr.common.FeatureTypes +import com.linkedin.feathr.common.FeatureValue import com.linkedin.feathr.offline.anchored.keyExtractor.AlienSourceKeyExtractorAdaptor import com.linkedin.feathr.offline.client.plugins.FeathrUdfPluginContext import com.linkedin.feathr.offline.derived.AlienDerivationFunctionAdaptor +import com.linkedin.feathr.offline.mvel.FeathrFeatureValueAsAlien import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext import com.linkedin.feathr.offline.plugins.{AlienFeatureValue, AlienFeatureValueTypeAdaptor} import com.linkedin.feathr.offline.util.FeathrTestUtils import org.apache.spark.sql.Row import org.apache.spark.sql.types.{FloatType, StringType, StructField, StructType} -import org.testng.Assert.assertEquals -import org.testng.annotations.Test - +import org.testng.Assert.{assertEquals, assertTrue} +import org.testng.annotations.{BeforeClass, Test} class TestFeathrUdfPlugins extends FeathrIntegTest { val MULTILINE_QUOTE = "\"\"\"" private val mvelContext = new FeathrExpressionExecutionContext() - // todo - support udf plugins through FCM - @Test (enabled = false) - def testMvelUdfPluginSupport: Unit = { + + @BeforeClass + override def setFeathrConfig(): Unit = { mvelContext.setupExecutorMvelContext(classOf[AlienFeatureValue], new AlienFeatureValueTypeAdaptor(), ss.sparkContext) FeathrUdfPluginContext.registerUdfAdaptor(new AlienDerivationFunctionAdaptor(), ss.sparkContext) FeathrUdfPluginContext.registerUdfAdaptor(new AlienSourceKeyExtractorAdaptor(), ss.sparkContext) + } + + @Test + def testFeatureValueWrapper(): Unit = { + val featureValue = new FeatureValue(2.0f) + val featureFeatureValueAsAlien = new FeathrFeatureValueAsAlien(featureValue) + assertTrue(mvelContext.canConvert(FeatureValue.getClass, featureFeatureValueAsAlien.getClass)) + assertEquals(mvelContext.convert(featureFeatureValueAsAlien, FeatureValue.getClass), featureValue) + } + + @Test (enabled = true) + def testMvelUdfPluginSupport: Unit = { val df = runLocalFeatureJoinForTest( joinConfigAsString = """ | features: { @@ -113,8 +125,6 @@ class TestFeathrUdfPlugins extends FeathrIntegTest { observationDataPath = "anchorAndDerivations/testMVELLoopExpFeature-observations.csv", mvelContext = Some(mvelContext)) - val f8Type = df.fdsMetadata.header.get.featureInfoMap.filter(_._1.getFeatureName == "f8").head._2.featureType.getFeatureType - assertEquals(f8Type, FeatureTypes.NUMERIC) val selectedColumns = Seq("a_id", "fA") val filteredDf = df.data.select(selectedColumns.head, selectedColumns.tail: _*) @@ -138,4 +148,6 @@ class TestFeathrUdfPlugins extends FeathrIntegTest { def cmpFunc(row: Row): String = row.get(0).toString FeathrTestUtils.assertDataFrameApproximatelyEquals(filteredDf, expectedDf, cmpFunc) } + + } diff --git a/feathr-impl/src/test/scala/com/linkedin/feathr/offline/mvel/FeathrFeatureValueAsAlien.scala b/feathr-impl/src/test/scala/com/linkedin/feathr/offline/mvel/FeathrFeatureValueAsAlien.scala new file mode 100644 index 000000000..4d3d78909 --- /dev/null +++ b/feathr-impl/src/test/scala/com/linkedin/feathr/offline/mvel/FeathrFeatureValueAsAlien.scala @@ -0,0 +1,9 @@ +package com.linkedin.feathr.offline.mvel + +import com.linkedin.feathr.common.FeatureValue +import com.linkedin.feathr.offline.mvel.plugins.FeatureValueWrapper +import com.linkedin.feathr.offline.plugins.AlienFeatureValue + +class FeathrFeatureValueAsAlien(feathrFeatureValue: FeatureValue) extends AlienFeatureValue with FeatureValueWrapper[FeatureValue] { + override def getFeatureValue(): FeatureValue = feathrFeatureValue +} diff --git a/gradle.properties b/gradle.properties index 64500a2a0..bd4e6fb38 100644 --- a/gradle.properties +++ b/gradle.properties @@ -1,3 +1,3 @@ -version=1.0.4-rc2 +version=1.0.4-rc3 SONATYPE_AUTOMATIC_RELEASE=true POM_ARTIFACT_ID=feathr_2.12