From a866b188a14c9c85060bdde262a048bdcb59a8af Mon Sep 17 00:00:00 2001
From: Ryan McCormick
Date: Mon, 29 Jul 2024 17:16:29 -0700
Subject: [PATCH 1/3] Add BF16 test for python backend

---
 qa/L0_backend_python/python_test.py         | 30 +++++++
 qa/L0_backend_python/test.sh                |  2 +
 qa/python_models/identity_bf16/config.pbtxt | 51 ++++++++++++
 qa/python_models/identity_bf16/model.py     | 89 +++++++++++++++++++++
 4 files changed, 172 insertions(+)
 create mode 100644 qa/python_models/identity_bf16/config.pbtxt
 create mode 100644 qa/python_models/identity_bf16/model.py

diff --git a/qa/L0_backend_python/python_test.py b/qa/L0_backend_python/python_test.py
index a7a44608cf..c2512600e2 100755
--- a/qa/L0_backend_python/python_test.py
+++ b/qa/L0_backend_python/python_test.py
@@ -365,6 +365,36 @@ def test_bool(self):
         self.assertIsNotNone(output0)
         self.assertTrue(np.all(output0 == input_data))
 
+    def test_bf16(self):
+        model_name = "identity_bf16"
+        shape = [2, 2]
+        with self._shm_leak_detector.Probe() as shm_probe:
+            with httpclient.InferenceServerClient(
+                f"{_tritonserver_ipaddr}:8000"
+            ) as client:
+                # NOTE: Client will truncate FP32 to BF16 internally
+                # since numpy has no built-in BF16 representation.
+                np_input = np.ones(shape, dtype=np.float32)
+                inputs = [
+                    httpclient.InferInput(
+                        "INPUT0", np_input.shape, "BF16"
+                    ).set_data_from_numpy(np_input)
+                ]
+                result = client.infer(model_name, inputs)
+
+                # Assert that Triton correctly returned a BF16 tensor.
+                response = result.get_response()
+                triton_output = response["outputs"][0]
+                triton_dtype = triton_output["datatype"]
+                self.assertEqual(triton_dtype, "BF16")
+
+                np_output = result.as_numpy("OUTPUT0")
+                self.assertIsNotNone(np_output)
+                # BF16 tensors are held in FP32 when converted to numpy due to
+                # lack of native BF16 support in numpy, so verify that here.
+                self.assertEqual(np_output.dtype, np.float32)
+                self.assertTrue(np.allclose(np_output, np_input))
+
     def test_infer_pytorch(self):
         # FIXME: This model requires torch. Because windows tests are not run in a docker
         # environment with torch installed, we need to think about how we want to install
diff --git a/qa/L0_backend_python/test.sh b/qa/L0_backend_python/test.sh
index 0e0240cd95..0cc34befe1 100755
--- a/qa/L0_backend_python/test.sh
+++ b/qa/L0_backend_python/test.sh
@@ -95,6 +95,8 @@ fi
 mkdir -p models/identity_fp32/1/
 cp ../python_models/identity_fp32/model.py ./models/identity_fp32/1/model.py
 cp ../python_models/identity_fp32/config.pbtxt ./models/identity_fp32/config.pbtxt
+cp ../python_models/identity_bf16/model.py ./models/identity_bf16/1/model.py
+cp ../python_models/identity_bf16/config.pbtxt ./models/identity_bf16/config.pbtxt
 
 RET=0
 cp -r ./models/identity_fp32 ./models/identity_uint8
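The NOTE in test_bf16 above says the client truncates FP32 to BF16 because numpy has no BF16 dtype. As a minimal standalone sketch of what that truncation means at the bit level (the helper names below are illustrative, not tritonclient APIs): BF16 reuses float32's sign and exponent bits and keeps only the top 7 mantissa bits, so the BF16 payload is simply the high 16 bits of the float32 encoding.

    import numpy as np

    def fp32_to_bf16_bits(arr):
        # Reinterpret the float32 bits and keep the high 16 bits; this is
        # the BF16 encoding obtained by truncation (no rounding).
        return (arr.astype(np.float32).view(np.uint32) >> 16).astype(np.uint16)

    def bf16_bits_to_fp32(bits):
        # Widen back to float32 by zero-filling the dropped mantissa bits.
        return (bits.astype(np.uint32) << 16).view(np.float32)

    x = np.array([1.0, 1.0 + 2.0**-20], dtype=np.float32)
    # 2**-20 is below BF16's 7-bit mantissa precision, so it is lost:
    print(bf16_bits_to_fp32(fp32_to_bf16_bits(x)))  # [1. 1.]

This is also why the test feeds np.ones: values exactly representable in BF16 survive the round trip unchanged, so the np.allclose assertion holds.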
diff --git a/qa/python_models/identity_bf16/config.pbtxt b/qa/python_models/identity_bf16/config.pbtxt
new file mode 100644
index 0000000000..e4d7df06c1
--- /dev/null
+++ b/qa/python_models/identity_bf16/config.pbtxt
@@ -0,0 +1,51 @@
+# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+backend: "python"
+max_batch_size: 64
+
+input [
+  {
+    name: "INPUT0"
+    data_type: TYPE_BF16
+    dims: [ -1 ]
+  }
+]
+
+output [
+  {
+    name: "OUTPUT0"
+    data_type: TYPE_BF16
+    dims: [ -1 ]
+  }
+]
+
+instance_group [
+  {
+    count: 1
+    kind: KIND_CPU
+  }
+]
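For reference, this config is what the model's initialize() receives through args["model_config"], serialized as a JSON string. A small sketch of the parsed structure, trimmed to the fields the model actually reads (the JSON literal is hand-written for illustration, not captured from a running server, and a real config carries many more fields):

    import json

    # Hand-written stand-in for the args["model_config"] string that
    # Triton passes to initialize().
    config_json = '''
    {
      "max_batch_size": 64,
      "input":  [{"name": "INPUT0",  "data_type": "TYPE_BF16"}],
      "output": [{"name": "OUTPUT0", "data_type": "TYPE_BF16"}]
    }
    '''
    config = json.loads(config_json)
    assert config["input"][0]["data_type"] == "TYPE_BF16"

This "data_type" string is what validate_bf16_tensor in model.py below checks. With max_batch_size set to 64, the first axis of each request tensor is the batch dimension, so the [2, 2] input in test_bf16 is a batch of two one-dimensional tensors.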
diff --git a/qa/python_models/identity_bf16/model.py b/qa/python_models/identity_bf16/model.py
new file mode 100644
index 0000000000..3a34a2f9a4
--- /dev/null
+++ b/qa/python_models/identity_bf16/model.py
@@ -0,0 +1,89 @@
+# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
+# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+import json
+
+import numpy as np
+import torch
+import triton_python_backend_utils as pb_utils
+
+
+class TritonPythonModel:
+    def initialize(self, args):
+        # args["model_config"] is an unparsed JSON string; parse it here.
+        self.model_config = json.loads(args["model_config"])
+
+        # Get tensor configurations for testing/validation
+        self.input0_config = pb_utils.get_input_config_by_name(
+            self.model_config, "INPUT0"
+        )
+        self.output0_config = pb_utils.get_output_config_by_name(
+            self.model_config, "OUTPUT0"
+        )
+
+    def validate_bf16_tensor(self, tensor, tensor_config):
+        # I/O datatypes can be queried from the model config if needed
+        dtype = tensor_config["data_type"]
+        if dtype == "TYPE_BF16":
+            # Converting BF16 tensors to numpy is not supported, and DLPack
+            # should be used instead via to_dlpack and from_dlpack.
+            try:
+                _ = tensor.as_numpy()
+            except pb_utils.TritonModelException as e:
+                expected_error = "tensor dtype is bf16 and cannot be converted to numpy"
+                assert expected_error in str(e).lower()
+            else:
+                raise Exception("Expected BF16 conversion to numpy to fail")
+        else:
+            raise Exception(f"Expected a BF16 tensor, but got {dtype} instead.")
+
+    def execute(self, requests):
+        """
+        Identity model in Python backend with example BF16 and PyTorch usage.
+        """
+        responses = []
+        for request in requests:
+            input_tensor = pb_utils.get_input_tensor_by_name(request, "INPUT0")
+
+            # Numpy does not support BF16, so use DLPack instead.
+            bf16_dlpack = input_tensor.to_dlpack()
+
+            # OPTIONAL: The tensor can be converted to other dlpack-compatible
+            # frameworks like PyTorch and TensorFlow with their dlpack utilities.
+            torch_tensor = torch.utils.dlpack.from_dlpack(bf16_dlpack)
+
+            # When complete, convert back to a pb_utils.Tensor via DLPack.
+            output_tensor = pb_utils.Tensor.from_dlpack(
+                "OUTPUT0", torch.utils.dlpack.to_dlpack(torch_tensor)
+            )
+            responses.append(pb_utils.InferenceResponse([output_tensor]))
+
+            # NOTE: The following validation calls are for testing and example
+            # purposes only; remove them in practice.
+            self.validate_bf16_tensor(input_tensor, self.input0_config)
+            self.validate_bf16_tensor(output_tensor, self.output0_config)
+
+        return responses
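The model above never touches numpy for the tensor payload; it hands the BF16 buffer to PyTorch through DLPack and back. The same round trip can be sketched in plain PyTorch, outside Triton, to show that the dtype survives the capsule exchange (a minimal sketch; the values and shape are arbitrary):

    import torch
    from torch.utils import dlpack

    # Mirrors model.py: BF16 tensor -> DLPack capsule -> framework tensor.
    t = torch.ones(2, 2, dtype=torch.bfloat16)
    capsule = dlpack.to_dlpack(t)      # zero-copy handoff; shares t's memory
    t2 = dlpack.from_dlpack(capsule)   # note: a capsule may be consumed only once

    assert t2.dtype == torch.bfloat16  # dtype is preserved end to end
    assert torch.equal(t, t2)

pb_utils.Tensor.from_dlpack("OUTPUT0", ...) then performs the reverse step on the Triton side, as execute() above shows.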
+        try:
+            _ = tensor.as_numpy()
+        except pb_utils.TritonModelException as e:
+            expected_error = "tensor dtype is bf16 and cannot be converted to numpy"
+            assert expected_error in str(e).lower()
+        else:
+            raise Exception("Expected BF16 conversion to numpy to fail")
+
     def execute(self, requests):
         """
         Identity model in Python backend with example BF16 and PyTorch usage.
         """

From c41a83d21c6db71c21fb1a5f1b00884c9e799f6a Mon Sep 17 00:00:00 2001
From: Ryan McCormick
Date: Tue, 30 Jul 2024 09:03:21 -0700
Subject: [PATCH 3/3] Fix folder

---
 qa/L0_backend_python/test.sh | 1 +
 1 file changed, 1 insertion(+)

diff --git a/qa/L0_backend_python/test.sh b/qa/L0_backend_python/test.sh
index 0cc34befe1..65767419f2 100755
--- a/qa/L0_backend_python/test.sh
+++ b/qa/L0_backend_python/test.sh
@@ -95,6 +95,7 @@ fi
 mkdir -p models/identity_fp32/1/
 cp ../python_models/identity_fp32/model.py ./models/identity_fp32/1/model.py
 cp ../python_models/identity_fp32/config.pbtxt ./models/identity_fp32/config.pbtxt
+mkdir -p models/identity_bf16/1/
 cp ../python_models/identity_bf16/model.py ./models/identity_bf16/1/model.py
 cp ../python_models/identity_bf16/config.pbtxt ./models/identity_bf16/config.pbtxt
 RET=0
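PATCH 3/3 is needed because cp does not create missing parent directories: PATCH 1/3 copied into models/identity_bf16/1/ while only models/identity_fp32/1/ had been created with mkdir -p. After the fix, test.sh builds the model repository in the layout the server expects (one directory per model, with config.pbtxt at the top and the model code under a numbered version directory):

    models/
    ├── identity_fp32/
    │   ├── config.pbtxt
    │   └── 1/
    │       └── model.py
    └── identity_bf16/
        ├── config.pbtxt
        └── 1/
            └── model.py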