[SPARK-33398] Fix loading tree models prior to Spark 3.0
### What changes were proposed in this pull request?
In https://github.com/apache/spark/pull/21632/files#diff-0fdae8a6782091746ed20ea43f77b639f9c6a5f072dd2f600fcf9a7b37db4f47, a new field `rawCount` was added to `NodeData`, which causes tree models trained in 2.4 to fail to load in 3.0/3.1/master.
The `rawCount` field is only used in training, not in `transform`/`predict`/`featureImportances`. So I just backfill it with `-1L` when loading an old model.
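A minimal sketch of the loading-side idea, using a hypothetical toy case class and path rather than the real `NodeData` schema:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.lit

// Toy stand-in for NodeData; only rawCount matters for this sketch.
case class ToyNode(id: Int, prediction: Double, rawCount: Long)

object BackfillSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("backfill").getOrCreate()
    import spark.implicits._

    // Pretend this Parquet file was written by Spark 2.4 and lacks `rawCount`.
    var df = spark.read.parquet("/tmp/old-model-data")
    if (!df.columns.contains("rawCount")) {
      df = df.withColumn("rawCount", lit(-1L)) // dummy value; unused at inference time
    }
    // Deserialization now succeeds because the schema matches the case class.
    df.as[ToyNode].collect().foreach(println)
    spark.stop()
  }
}
```

The change below applies exactly this backfill in the tree-model readers, plus a struct rebuild for ensembles, where `rawCount` is nested inside `nodeData`.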

### Why are the changes needed?
To support loading old tree models in 3.0/3.1/master.

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Added test suites.

Closes apache#30889 from zhengruifeng/fix_tree_load.

Authored-by: Ruifeng Zheng <[email protected]>
Signed-off-by: Sean Owen <[email protected]>
zhengruifeng authored and srowen committed Jan 3, 2021
1 parent 963c60f commit 6b7527e
Showing 74 changed files with 122 additions and 20 deletions.
mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala (33 additions, 15 deletions):
@@ -31,8 +31,10 @@ import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter}
 import org.apache.spark.ml.util.DefaultParamsReader.Metadata
 import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
 import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{Dataset, SparkSession}
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.{col, lit, struct}
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.VersionUtils
 import org.apache.spark.util.collection.OpenHashMap

/**
@@ -401,8 +403,13 @@ private[ml] object DecisionTreeModelReadWrite {
     }

     val dataPath = new Path(path, "data").toString
-    val data = sparkSession.read.parquet(dataPath).as[NodeData]
-    buildTreeFromNodes(data.collect(), impurityType)
+    var df = sparkSession.read.parquet(dataPath)
+    val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
+    if (major.toInt < 3) {
+      df = df.withColumn("rawCount", lit(-1L))
+    }
+
+    buildTreeFromNodes(df.as[NodeData].collect(), impurityType)
   }
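For context, `VersionUtils.majorMinorVersion` parses the `sparkVersion` string persisted in the model metadata, so the backfill only fires for models written before 3.0. A self-contained re-implementation of that guard, for illustration only (the real helper lives in `org.apache.spark.util` and is internal to Spark):

```scala
object VersionGuardSketch {
  // Roughly mirrors VersionUtils.majorMinorVersion: "2.4.7" => (2, 4).
  private val versionRegex = """^(\d+)\.(\d+)(\..*)?$""".r

  def majorMinor(sparkVersion: String): (Int, Int) = sparkVersion match {
    case versionRegex(major, minor, _) => (major.toInt, minor.toInt)
    case _ => throw new IllegalArgumentException(s"Cannot parse Spark version: $sparkVersion")
  }

  def main(args: Array[String]): Unit = {
    assert(majorMinor("2.4.7") == ((2, 4))) // pre-3.0 model: backfill rawCount
    assert(majorMinor("3.1.0") == ((3, 1))) // 3.0+ model: rawCount already present
    println("version guard sketch OK")
  }
}
```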

/**
@@ -497,25 +504,36 @@ private[ml] object EnsembleModelReadWrite {
     }

     val treesMetadataPath = new Path(path, "treesMetadata").toString
-    val treesMetadataRDD: RDD[(Int, (Metadata, Double))] = sql.read.parquet(treesMetadataPath)
-      .select("treeID", "metadata", "weights").as[(Int, String, Double)].rdd.map {
-      case (treeID: Int, json: String, weights: Double) =>
+    val treesMetadataRDD = sql.read.parquet(treesMetadataPath)
+      .select("treeID", "metadata", "weights")
+      .as[(Int, String, Double)].rdd
+      .map { case (treeID: Int, json: String, weights: Double) =>
         treeID -> ((DefaultParamsReader.parseMetadata(json, treeClassName), weights))
-    }
+      }

     val treesMetadataWeights = treesMetadataRDD.sortByKey().values.collect()
     val treesMetadata = treesMetadataWeights.map(_._1)
     val treesWeights = treesMetadataWeights.map(_._2)

     val dataPath = new Path(path, "data").toString
-    val nodeData: Dataset[EnsembleNodeData] =
-      sql.read.parquet(dataPath).as[EnsembleNodeData]
-    val rootNodesRDD: RDD[(Int, Node)] =
-      nodeData.rdd.map(d => (d.treeID, d.nodeData)).groupByKey().map {
-        case (treeID: Int, nodeData: Iterable[NodeData]) =>
-          treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
+    var df = sql.read.parquet(dataPath)
+    val (major, minor) = VersionUtils.majorMinorVersion(metadata.sparkVersion)
+    if (major.toInt < 3) {
+      val newNodeDataCol = df.schema("nodeData").dataType match {
+        case StructType(fields) =>
+          val cols = fields.map(f => col(s"nodeData.${f.name}")) :+ lit(-1L).as("rawCount")
+          struct(cols: _*)
+      }
+      df = df.withColumn("nodeData", newNodeDataCol)
+    }
+
+    val rootNodesRDD = df.as[EnsembleNodeData].rdd
+      .map(d => (d.treeID, d.nodeData))
+      .groupByKey()
+      .map { case (treeID: Int, nodeData: Iterable[NodeData]) =>
+        treeID -> DecisionTreeModelReadWrite.buildTreeFromNodes(nodeData.toArray, impurityType)
       }
-    val rootNodes: Array[Node] = rootNodesRDD.sortByKey().values.collect()
+    val rootNodes = rootNodesRDD.sortByKey().values.collect()
     (metadata, treesMetadata.zip(rootNodes), treesWeights)
   }
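The ensemble reader is the subtler half of the fix: there `rawCount` sits inside the nested `nodeData` struct, and `withColumn` can only replace top-level columns, so the struct is rebuilt from its existing fields with the dummy field appended. A toy demonstration of the same pattern, with an illustrative schema rather than the real `EnsembleNodeData`:

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.{col, lit, struct}
import org.apache.spark.sql.types.StructType

object NestedBackfillSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[1]").appName("nested").getOrCreate()
    import spark.implicits._

    // A frame whose `node` struct lacks `rawCount`, like a 2.4-era ensemble file.
    val old = Seq((0, 1, 0.5), (1, 2, 0.25)).toDF("treeID", "id", "prediction")
      .select(col("treeID"), struct(col("id"), col("prediction")).as("node"))

    // Rebuild the struct: keep every existing field, append a dummy rawCount.
    val rebuilt = old.schema("node").dataType match {
      case StructType(fields) =>
        struct((fields.map(f => col(s"node.${f.name}")) :+ lit(-1L).as("rawCount")): _*)
    }
    val patched = old.withColumn("node", rebuilt)
    patched.printSchema() // node now carries id, prediction and rawCount
    spark.stop()
  }
}
```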

(Binary and empty test-resource files, the saved 2.4.7 model data under ml-models/, are not shown.)

ml-models/dtc-2.4.7/metadata:
@@ -0,0 +1 @@
{"class":"org.apache.spark.ml.classification.DecisionTreeClassificationModel","timestamp":1608687929358,"sparkVersion":"2.4.7","uid":"dtc_bc7ad285bb73","paramMap":{},"defaultParamMap":{"impurity":"gini","maxDepth":5,"labelCol":"label","maxMemoryInMB":256,"featuresCol":"features","predictionCol":"prediction","minInfoGain":0.0,"seed":159147643,"rawPredictionCol":"rawPrediction","minInstancesPerNode":1,"cacheNodeIds":false,"probabilityCol":"probability","maxBins":32,"checkpointInterval":10},"numFeatures":692,"numClasses":2}
(Binary and empty files not shown.)

ml-models/dtr-2.4.7/metadata:
@@ -0,0 +1 @@
{"class":"org.apache.spark.ml.regression.DecisionTreeRegressionModel","timestamp":1608687932847,"sparkVersion":"2.4.7","uid":"dtr_c16a90fcdaf8","paramMap":{},"defaultParamMap":{"labelCol":"label","checkpointInterval":10,"minInfoGain":0.0,"maxMemoryInMB":256,"minInstancesPerNode":1,"maxBins":32,"seed":926680331,"cacheNodeIds":false,"maxDepth":5,"predictionCol":"prediction","featuresCol":"features","impurity":"variance"},"numFeatures":692}
(Binary and empty files not shown.)

ml-models/gbtc-2.4.7/metadata:
@@ -0,0 +1 @@
{"class":"org.apache.spark.ml.classification.GBTClassificationModel","timestamp":1608687932103,"sparkVersion":"2.4.7","uid":"gbtc_81db008b4f25","paramMap":{"maxIter":2},"defaultParamMap":{"seed":-1287390502,"maxMemoryInMB":256,"stepSize":0.1,"validationTol":0.01,"maxBins":32,"checkpointInterval":10,"predictionCol":"prediction","lossType":"logistic","rawPredictionCol":"rawPrediction","featuresCol":"features","cacheNodeIds":false,"maxIter":20,"featureSubsetStrategy":"all","impurity":"gini","minInstancesPerNode":1,"minInfoGain":0.0,"maxDepth":5,"subsamplingRate":1.0,"labelCol":"label","probabilityCol":"probability"},"numFeatures":692,"numTrees":2}
(Binary and empty files not shown.)

ml-models/gbtr-2.4.7/metadata:
@@ -0,0 +1 @@
{"class":"org.apache.spark.ml.regression.GBTRegressionModel","timestamp":1608687942434,"sparkVersion":"2.4.7","uid":"gbtr_0a74cb2536ff","paramMap":{"maxIter":2},"defaultParamMap":{"impurity":"variance","maxMemoryInMB":256,"maxDepth":5,"subsamplingRate":1.0,"validationTol":0.01,"labelCol":"label","maxIter":20,"checkpointInterval":10,"minInfoGain":0.0,"predictionCol":"prediction","stepSize":0.1,"cacheNodeIds":false,"lossType":"squared","seed":-131597770,"featureSubsetStrategy":"all","featuresCol":"features","minInstancesPerNode":1,"maxBins":32},"numFeatures":692,"numTrees":2}
(Binary and empty files not shown.)

ml-models/rfc-2.4.7/metadata:
@@ -0,0 +1 @@
{"class":"org.apache.spark.ml.classification.RandomForestClassificationModel","timestamp":1608687930713,"sparkVersion":"2.4.7","uid":"rfc_db1adb353f1e","paramMap":{"numTrees":2},"defaultParamMap":{"impurity":"gini","predictionCol":"prediction","numTrees":20,"maxDepth":5,"featureSubsetStrategy":"auto","subsamplingRate":1.0,"featuresCol":"features","checkpointInterval":10,"rawPredictionCol":"rawPrediction","cacheNodeIds":false,"labelCol":"label","seed":207336481,"probabilityCol":"probability","maxBins":32,"minInstancesPerNode":1,"minInfoGain":0.0,"maxMemoryInMB":256},"numFeatures":692,"numClasses":2,"numTrees":2}
(Binary and empty files not shown.)

ml-models/rfr-2.4.7/metadata:
@@ -0,0 +1 @@
{"class":"org.apache.spark.ml.regression.RandomForestRegressionModel","timestamp":1608687933536,"sparkVersion":"2.4.7","uid":"rfr_d946d96b7ff0","paramMap":{"numTrees":2},"defaultParamMap":{"numTrees":20,"featureSubsetStrategy":"auto","maxDepth":5,"minInstancesPerNode":1,"labelCol":"label","cacheNodeIds":false,"checkpointInterval":10,"featuresCol":"features","maxMemoryInMB":256,"predictionCol":"prediction","minInfoGain":0.0,"subsamplingRate":1.0,"impurity":"variance","seed":235498149,"maxBins":32},"numFeatures":692,"numTrees":2}
(Binary and empty files not shown.)

mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala:
@@ -446,6 +446,18 @@ class DecisionTreeClassifierSuite extends MLTest with DefaultReadWriteTest {

testDefaultReadWrite(model)
}

test("SPARK-33398: Load DecisionTreeClassificationModel prior to Spark 3.0") {
val path = testFile("ml-models/dtc-2.4.7")
val model = DecisionTreeClassificationModel.load(path)
assert(model.numClasses === 2)
assert(model.numFeatures === 692)
assert(model.numNodes === 5)

val metadata = spark.read.json(s"$path/metadata")
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
assert(sparkVersionStr === "2.4.7")
}
}

private[ml] object DecisionTreeClassifierSuite extends SparkFunSuite {
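From the user's side, the fix means a 3.0+ session can now load a directory saved by `model.save(...)` on Spark 2.4.x; a spark-shell sketch with an illustrative path:

```scala
import org.apache.spark.ml.classification.DecisionTreeClassificationModel

val model = DecisionTreeClassificationModel.load("/models/dtc-saved-with-2.4.7")
println(model.numNodes)      // tree structure is fully restored
println(model.toDebugString) // inference-side APIs never read rawCount
```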

mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala:
@@ -545,6 +545,20 @@ class GBTClassifierSuite extends MLTest with DefaultReadWriteTest {
testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
allParamSettings, checkModelData)
}

test("SPARK-33398: Load GBTClassificationModel prior to Spark 3.0") {
val path = testFile("ml-models/gbtc-2.4.7")
val model = GBTClassificationModel.load(path)
assert(model.numClasses === 2)
assert(model.numFeatures === 692)
assert(model.getNumTrees === 2)
assert(model.totalNumNodes === 22)
assert(model.trees.map(_.numNodes) === Array(5, 17))

val metadata = spark.read.json(s"$path/metadata")
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
assert(sparkVersionStr === "2.4.7")
}
}

private object GBTClassifierSuite extends SparkFunSuite {

mllib/src/test/scala/org/apache/spark/ml/classification/MultilayerPerceptronClassifierSuite.scala:
@@ -240,7 +240,7 @@ class MultilayerPerceptronClassifierSuite extends MLTest with DefaultReadWriteTest {

val metadata = spark.read.json(s"$mlpPath/metadata")
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
-    assert(sparkVersionStr == "2.4.4")
+    assert(sparkVersionStr === "2.4.4")
}
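(The `==` to `===` changes here and in the feature suites switch plain Boolean assertions to ScalaTest's `===`, which reports both operands on failure, for example `"2.4.7" did not equal "2.4.4"`, instead of a bare `assertion failed`.)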

test("summary and training summary") {

mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala:
@@ -25,10 +25,10 @@ import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.tree._
 import org.apache.spark.ml.tree.impl.TreeTests
 import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
+import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.regression.{LabeledPoint => OldLabeledPoint}
 import org.apache.spark.mllib.tree.{EnsembleTestHelper, RandomForest => OldRandomForest}
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
-import org.apache.spark.mllib.util.TestingUtils._
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Row}
 import org.apache.spark.sql.functions._

@@ -429,6 +429,20 @@ class RandomForestClassifierSuite extends MLTest with DefaultReadWriteTest {
testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
allParamSettings, checkModelData)
}

test("SPARK-33398: Load RandomForestClassificationModel prior to Spark 3.0") {
val path = testFile("ml-models/rfc-2.4.7")
val model = RandomForestClassificationModel.load(path)
assert(model.numClasses === 2)
assert(model.numFeatures === 692)
assert(model.getNumTrees === 2)
assert(model.totalNumNodes === 10)
assert(model.trees.map(_.numNodes) === Array(3, 7))

val metadata = spark.read.json(s"$path/metadata")
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
assert(sparkVersionStr === "2.4.7")
}
}

private object RandomForestClassifierSuite extends SparkFunSuite {

mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala:
@@ -99,7 +99,7 @@ class HashingTFSuite extends MLTest with DefaultReadWriteTest {

val metadata = spark.read.json(s"$hashingTFPath/metadata")
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
-    assert(sparkVersionStr == "2.4.4")
+    assert(sparkVersionStr === "2.4.4")

intercept[IllegalArgumentException] {
loadedHashingTF.save(hashingTFPath)

mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala:
@@ -483,6 +483,6 @@ class StringIndexerSuite extends MLTest with DefaultReadWriteTest {

val metadata = spark.read.json(s"$modelPath/metadata")
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
-    assert(sparkVersionStr == "2.4.4")
+    assert(sparkVersionStr === "2.4.4")
}
}

mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala:
@@ -19,7 +19,7 @@ package org.apache.spark.ml.regression

import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.feature.LabeledPoint
-import org.apache.spark.ml.linalg.Vector
+import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.tree.impl.TreeTests
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
@@ -236,6 +236,20 @@ class DecisionTreeRegressorSuite extends MLTest with DefaultReadWriteTest {
TreeTests.allParamSettings ++ Map("maxDepth" -> 0),
TreeTests.allParamSettings ++ Map("maxDepth" -> 0), checkModelData)
}

test("SPARK-33398: Load DecisionTreeRegressionModel prior to Spark 3.0") {
val path = testFile("ml-models/dtr-2.4.7")
val model = DecisionTreeRegressionModel.load(path)
assert(model.numFeatures === 692)
assert(model.numNodes === 5)
assert(model.featureImportances ~==
Vectors.sparse(692, Array(100, 434),
Array(0.03987240829346093, 0.960127591706539)) absTol 1e-4)

val metadata = spark.read.json(s"$path/metadata")
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
assert(sparkVersionStr === "2.4.7")
}
}

private[ml] object DecisionTreeRegressorSuite extends SparkFunSuite {

mllib/src/test/scala/org/apache/spark/ml/regression/GBTRegressorSuite.scala:
@@ -370,6 +370,18 @@ class GBTRegressorSuite extends MLTest with DefaultReadWriteTest {
testEstimatorAndModelReadWrite(gbt, continuousData, allParamSettings,
allParamSettings, checkModelData)
}

test("SPARK-33398: Load GBTRegressionModel prior to Spark 3.0") {
val path = testFile("ml-models/gbtr-2.4.7")
val model = GBTRegressionModel.load(path)
assert(model.numFeatures === 692)
assert(model.totalNumNodes === 6)
assert(model.trees.map(_.numNodes) === Array(5, 1))

val metadata = spark.read.json(s"$path/metadata")
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
assert(sparkVersionStr === "2.4.7")
}
}

private object GBTRegressorSuite extends SparkFunSuite {

mllib/src/test/scala/org/apache/spark/ml/regression/RandomForestRegressorSuite.scala:
@@ -221,6 +221,18 @@ class RandomForestRegressorSuite extends MLTest with DefaultReadWriteTest{
testEstimatorAndModelReadWrite(rf, continuousData, allParamSettings,
allParamSettings, checkModelData)
}

test("SPARK-33398: Load RandomForestRegressionModel prior to Spark 3.0") {
val path = testFile("ml-models/rfr-2.4.7")
val model = RandomForestRegressionModel.load(path)
assert(model.numFeatures === 692)
assert(model.totalNumNodes === 8)
assert(model.trees.map(_.numNodes) === Array(5, 3))

val metadata = spark.read.json(s"$path/metadata")
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
assert(sparkVersionStr === "2.4.7")
}
}

private object RandomForestRegressorSuite extends SparkFunSuite {