forked from onnx/onnxmltools
feat: SparkML MultilayerPerceptronClassifier (onnx#570)
* feat: SparkML MultilayerPerceptronClassifier
  Signed-off-by: Jason Wang <[email protected]>
* imports
  Signed-off-by: Jason Wang <[email protected]>
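For context, a minimal sketch of the conversion path this commit enables. It assumes a running SparkSession and a training DataFrame train_df with "features" and "label" columns (train_df is illustrative, not part of the commit); the complete, runnable flow is in the new test further below.

# Minimal sketch of the new SparkML -> ONNX conversion path; `train_df` is an
# assumed Spark DataFrame with a vector "features" column and a "label" column.
from pyspark.ml.classification import MultilayerPerceptronClassifier
from onnxmltools import convert_sparkml
from onnxmltools.convert.common.data_types import FloatTensorType

mlp = MultilayerPerceptronClassifier(layers=[100, 20, 5, 2], maxIter=100, seed=137)
model = mlp.fit(train_df)

model_onnx = convert_sparkml(
    model,
    "sparkml multilayer perceptron classifier",
    [("features", FloatTensorType([None, model.numFeatures]))],
)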
Showing 4 changed files with 175 additions and 1 deletion.
onnxmltools/convert/sparkml/operator_converters/mlp_classifier.py (101 additions, 0 deletions)
@@ -0,0 +1,101 @@
# SPDX-License-Identifier: Apache-2.0

from pyspark.ml.classification import MultilayerPerceptronClassificationModel

from ...common._registration import register_converter, register_shape_calculator
from ...common.data_types import Int64TensorType, FloatTensorType
from ...common.utils import check_input_and_output_numbers, check_input_and_output_types
from ...common._topology import Operator, Scope, ModelComponentContainer
from ....proto import onnx_proto
from typing import List
import numpy as np


def convert_sparkml_mlp_classifier(scope: Scope, operator: Operator, container: ModelComponentContainer):
    op: MultilayerPerceptronClassificationModel = operator.raw_operator
    layers: List[int] = op.getLayers()
    weights: np.ndarray = op.weights.toArray()

    # Spark ML packs every layer's weight matrix and bias into one flat vector;
    # walk through it layer by layer with `offset`.
    offset = 0

    input: str
    for i in range(len(layers) - 1):
        weight_matrix = weights[offset : offset + layers[i] * layers[i + 1]].reshape([layers[i], layers[i + 1]])
        offset += layers[i] * layers[i + 1]
        bias_vector = weights[offset : offset + layers[i + 1]]
        offset += layers[i + 1]

        if i == 0:
            # the first Gemm reads the graph input directly
            input = operator.inputs[0].full_name

        weight_variable = scope.get_unique_variable_name("w")
        container.add_initializer(
            weight_variable,
            onnx_proto.TensorProto.FLOAT,
            weight_matrix.shape,
            weight_matrix.flatten().astype(np.float32),
        )

        bias_variable = scope.get_unique_variable_name("b")
        container.add_initializer(
            bias_variable, onnx_proto.TensorProto.FLOAT, bias_vector.shape, bias_vector.astype(np.float32),
        )

        gemm_output_variable = scope.get_unique_variable_name("gemm_output")
        container.add_node(
            op_type="Gemm",
            inputs=[input, weight_variable, bias_variable],
            outputs=[gemm_output_variable],
            op_version=7,
            name=scope.get_unique_operator_name("Gemm"),
        )

        if i == len(layers) - 2:
            # output layer: softmax over the class scores, written to the probability output
            container.add_node(
                op_type="Softmax",
                inputs=[gemm_output_variable],
                outputs=[operator.outputs[1].full_name],
                op_version=1,
                name=scope.get_unique_operator_name("Softmax"),
            )
        else:
            # hidden layer: sigmoid activation feeds the next Gemm
            input = scope.get_unique_variable_name("activation_output")
            container.add_node(
                op_type="Sigmoid",
                inputs=[gemm_output_variable],
                outputs=[input],
                op_version=1,
                name=scope.get_unique_operator_name("Sigmoid"),
            )

    # prediction = index of the largest class probability
    container.add_node(
        "ArgMax",
        [operator.outputs[1].full_name],
        [operator.outputs[0].full_name],
        name=scope.get_unique_operator_name("ArgMax"),
        axis=1,
        keepdims=0,
    )


register_converter("pyspark.ml.classification.MultilayerPerceptronClassificationModel", convert_sparkml_mlp_classifier)


def calculate_mlp_classifier_output_shapes(operator: Operator):
    op: MultilayerPerceptronClassificationModel = operator.raw_operator

    check_input_and_output_numbers(operator, input_count_range=1, output_count_range=[1, 2])
    check_input_and_output_types(operator, good_input_types=[FloatTensorType, Int64TensorType])

    if len(operator.inputs[0].type.shape) != 2:
        raise RuntimeError("Input must be a [N, C]-tensor")

    N = operator.inputs[0].type.shape[0]
    operator.outputs[0].type = Int64TensorType(shape=[N])
    class_count = op.numClasses
    operator.outputs[1].type = FloatTensorType([N, class_count])


register_shape_calculator(
    "pyspark.ml.classification.MultilayerPerceptronClassificationModel", calculate_mlp_classifier_output_shapes
)
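The converter above relies on Spark ML's flat parameter layout: for each layer transition, the layers[i] x layers[i + 1] weight matrix comes first, immediately followed by the layers[i + 1] bias entries, which is what the offset bookkeeping unpacks. A minimal plain-numpy sketch (no Spark required; layer sizes borrowed from the test below) of the forward pass that the emitted Gemm/Sigmoid/Softmax/ArgMax nodes reproduce:

import numpy as np

layers = [100, 20, 5, 2]                              # topology used in the test below
# flat parameter vector: per transition, weight matrix entries then bias entries
total = sum(layers[i] * layers[i + 1] + layers[i + 1] for i in range(len(layers) - 1))
weights = np.random.rand(total).astype(np.float32)    # stand-in for op.weights.toArray()

x = np.random.rand(1, layers[0]).astype(np.float32)   # one input row
offset = 0
for i in range(len(layers) - 1):
    w = weights[offset : offset + layers[i] * layers[i + 1]].reshape(layers[i], layers[i + 1])
    offset += layers[i] * layers[i + 1]
    b = weights[offset : offset + layers[i + 1]]
    offset += layers[i + 1]
    z = x @ w + b                                      # what each Gemm node computes
    if i == len(layers) - 2:
        x = np.exp(z) / np.exp(z).sum()                # Softmax on the output layer
    else:
        x = 1.0 / (1.0 + np.exp(-z))                   # Sigmoid on the hidden layers

assert offset == total                                 # every parameter consumed exactly once
print(x)                                               # class probabilities, shape (1, 2)
print(x.argmax(axis=1))                                # prediction, as the ArgMax node computes it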
@@ -0,0 +1,68 @@
# SPDX-License-Identifier: Apache-2.0

import sys
import unittest
import inspect
import os
import numpy
import pandas
from pyspark.ml.classification import MultilayerPerceptronClassifier, MultilayerPerceptronClassificationModel
from pyspark.ml.linalg import VectorUDT, SparseVector
from onnx.defs import onnx_opset_version
from onnxconverter_common.onnx_ex import DEFAULT_OPSET_NUMBER
from onnxmltools import convert_sparkml
from onnxmltools.convert.common.data_types import FloatTensorType
from tests.sparkml.sparkml_test_utils import save_data_models, run_onnx_model, compare_results
from tests.sparkml import SparkMlTestCase


TARGET_OPSET = min(DEFAULT_OPSET_NUMBER, onnx_opset_version())


class TestSparkmlMLPClassifier(SparkMlTestCase):
    @unittest.skipIf(sys.version_info < (3, 8), reason="pickle fails on python 3.7")
    def test_model_mlp_classifier_binary_class(self):
        this_script_dir = os.path.dirname(os.path.abspath(inspect.getfile(inspect.currentframe())))
        input_path = os.path.join(this_script_dir, "data", "sample_libsvm_data.txt")
        original_data = self.spark.read.format("libsvm").load(input_path)
        #
        # truncate the features
        #
        self.spark.udf.register(
            "truncateFeatures", lambda x: SparseVector(100, range(0, 100), x.toArray()[30:130]), VectorUDT()
        )

        data = original_data.selectExpr("label", "truncateFeatures(features) as features")

        mlp = MultilayerPerceptronClassifier(maxIter=100, tol=0.0001, seed=137, layers=[100, 20, 5, 2],)
        model: MultilayerPerceptronClassificationModel = mlp.fit(data)

        # the ONNX graph input is named 'features', matching the Spark ML features column
        C = model.numFeatures
        model_onnx = convert_sparkml(
            model,
            "sparkml multilayer perceptron classifier",
            [("features", FloatTensorType([None, C]))],
            target_opset=TARGET_OPSET,
        )

        self.assertTrue(model_onnx is not None)

        # run the Spark model to get the expected outputs
        predicted = model.transform(data)
        # predicted.select("prediction", "probability", "label").show(100, truncate=False)

        data_np = data.toPandas().features.apply(lambda x: pandas.Series(x.toArray())).values.astype(numpy.float32)
        expected = [
            predicted.toPandas().prediction.values.astype(numpy.float32),
            predicted.toPandas().probability.apply(lambda x: pandas.Series(x.toArray())).values.astype(numpy.float32),
        ]

        paths = save_data_models(data_np, expected, model, model_onnx, basename="SparkmlMLPClassifier")
        onnx_model_path = paths[-1]
        output, output_shapes = run_onnx_model(["prediction", "probability"], data_np, onnx_model_path)
        compare_results(expected, output, decimal=5)


if __name__ == "__main__":
    unittest.main()
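Outside the test helpers, the converted graph can also be scored directly with onnxruntime. A minimal sketch, assuming model_onnx and data_np from the test above are in scope; the input name "features" and the output names "prediction" and "probability" match what the test declares and requests:

# Minimal sketch: score the converted model with onnxruntime directly,
# assuming `model_onnx` and `data_np` from the test above.
import onnxruntime as rt

sess = rt.InferenceSession(model_onnx.SerializeToString(), providers=["CPUExecutionProvider"])
prediction, probability = sess.run(["prediction", "probability"], {"features": data_np})
print(prediction.shape)   # (N,) class indices from the ArgMax node
print(probability.shape)  # (N, 2) class probabilities from the Softmax node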