From 7cf932be414e1d16e41ac07faa4b5d149f5214ec Mon Sep 17 00:00:00 2001 From: Katherine Yang Date: Tue, 7 May 2024 12:14:40 -0700 Subject: [PATCH 1/9] add byte checking --- src/infer_request.cc | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/infer_request.cc b/src/infer_request.cc index 1015db21d..3dc00e4ff 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -1134,6 +1134,22 @@ InferenceRequest::Normalize() implicit_batch_note); } } + // Matching incoming request's shape and byte size to make sure the + // payload contains correct number of elements + { + const auto& byte_size = input.Data()->TotalByteSize(); + const auto& data_type = input.DType(); + const auto& input_dims = *shape; + const auto expected_byte_size = + triton::common::GetByteSize(data_type, input_dims); + if (byte_size != expected_byte_size) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "input byte size mismatch for input '" + pr.first + + "' for model '" + ModelName() + "'. Expected " + byte_size + + ", got " + expected_byte_size); + } + } // If there is a reshape for this input then adjust them to // match the reshape. As reshape may have variable-size From 5728fed6b2557c5a59a2c6b5000e60a575923de5 Mon Sep 17 00:00:00 2001 From: Katherine Yang Date: Tue, 7 May 2024 16:07:21 -0700 Subject: [PATCH 2/9] update core repo includes and conversions --- src/infer_request.cc | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/src/infer_request.cc b/src/infer_request.cc index 3dc00e4ff..44d43b8d5 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -28,6 +28,7 @@ #include #include +#include #include "constants.h" #include "model.h" @@ -1137,17 +1138,19 @@ InferenceRequest::Normalize() // Matching incoming request's shape and byte size to make sure the // payload contains correct number of elements { - const auto& byte_size = input.Data()->TotalByteSize(); + const size_t& byte_size = input.Data()->TotalByteSize(); const auto& data_type = input.DType(); const auto& input_dims = *shape; - const auto expected_byte_size = + const int64_t expected_byte_size = triton::common::GetByteSize(data_type, input_dims); - if (byte_size != expected_byte_size) { + if ((byte_size > INT_MAX) || + (std::static_cast(byte_size) != expected_byte_size)) { return Status( Status::Code::INVALID_ARG, LogRequest() + "input byte size mismatch for input '" + pr.first + - "' for model '" + ModelName() + "'. Expected " + byte_size + - ", got " + expected_byte_size); + "' for model '" + ModelName() + "'. Expected " + + std::to_string(byte_size) + ", got " + + std::to_string(expected_byte_size)); } } From 63fd3e411cae423180c436edbd97a850160ec47e Mon Sep 17 00:00:00 2001 From: Katherine Yang Date: Tue, 7 May 2024 16:26:13 -0700 Subject: [PATCH 3/9] remove std from static cast --- src/infer_request.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/infer_request.cc b/src/infer_request.cc index 44d43b8d5..8fc4a3d1d 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -1144,7 +1144,7 @@ InferenceRequest::Normalize() const int64_t expected_byte_size = triton::common::GetByteSize(data_type, input_dims); if ((byte_size > INT_MAX) || - (std::static_cast(byte_size) != expected_byte_size)) { + (static_cast(byte_size) != expected_byte_size)) { return Status( Status::Code::INVALID_ARG, LogRequest() + "input byte size mismatch for input '" + pr.first + From ec13a8c9c0045eaf70f24d93f2e4e5590ecb556d Mon Sep 17 00:00:00 2001 From: Katherine Yang Date: Wed, 8 May 2024 18:41:03 -0700 Subject: [PATCH 4/9] update infer_request with actual dimension size and add count for STRING --- src/infer_request.cc | 49 ++++++++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 18 deletions(-) diff --git a/src/infer_request.cc b/src/infer_request.cc index 8fc4a3d1d..585729b1e 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -1135,24 +1135,6 @@ InferenceRequest::Normalize() implicit_batch_note); } } - // Matching incoming request's shape and byte size to make sure the - // payload contains correct number of elements - { - const size_t& byte_size = input.Data()->TotalByteSize(); - const auto& data_type = input.DType(); - const auto& input_dims = *shape; - const int64_t expected_byte_size = - triton::common::GetByteSize(data_type, input_dims); - if ((byte_size > INT_MAX) || - (static_cast(byte_size) != expected_byte_size)) { - return Status( - Status::Code::INVALID_ARG, - LogRequest() + "input byte size mismatch for input '" + pr.first + - "' for model '" + ModelName() + "'. Expected " + - std::to_string(byte_size) + ", got " + - std::to_string(expected_byte_size)); - } - } // If there is a reshape for this input then adjust them to // match the reshape. As reshape may have variable-size @@ -1188,6 +1170,37 @@ InferenceRequest::Normalize() input.MutableShapeWithBatchDim()->push_back(d); } } + // Matching incoming request's shape and byte size to make sure the + // payload contains correct number of elements. + // Note: Since we're using normalized input.ShapeWithBatchDim() here, + // make sure that all the normalization is before the check. + { + const size_t& byte_size = input.Data()->TotalByteSize(); + const auto& data_type = input.DType(); + const auto& input_dims = input.ShapeWithBatchDim(); + int64_t expected_byte_size = INT_MAX; + // Because Triton expects STRING type to be in special format + // (prepend 4 bytes to specify string length), so need to add all the + // first 4 bytes for each element to find expected byte size + if (data_type == inference::DataType::TYPE_STRING) { + int64_t element_count = triton::common::GetElementCount(input_dims); + expected_byte_size = 0; + for (size_t i = 0; i < element_count; ++i) { + expected_byte_size += 4; // FIXME: Actually add the byte size + } + } else { + expected_byte_size = triton::common::GetByteSize(data_type, input_dims); + } + if ((byte_size > INT_MAX) || + (static_cast(byte_size) != expected_byte_size)) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "input byte size mismatch for input '" + pr.first + + "' for model '" + ModelName() + "'. Expected " + + std::to_string(expected_byte_size) + ", got " + + std::to_string(byte_size)); + } + } } return Status::Success; From f2e4846c23fc5d5a642d89fea5e8bbad7330335a Mon Sep 17 00:00:00 2001 From: Katherine Yang Date: Thu, 9 May 2024 10:48:32 -0700 Subject: [PATCH 5/9] fix bug --- src/infer_request.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/infer_request.cc b/src/infer_request.cc index 585729b1e..fed087263 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -1185,7 +1185,7 @@ InferenceRequest::Normalize() if (data_type == inference::DataType::TYPE_STRING) { int64_t element_count = triton::common::GetElementCount(input_dims); expected_byte_size = 0; - for (size_t i = 0; i < element_count; ++i) { + for (int i = 0; i < element_count; ++i) { expected_byte_size += 4; // FIXME: Actually add the byte size } } else { From 229c92ccc978d92d8e1855eef6d6e10b3f41a15e Mon Sep 17 00:00:00 2001 From: Yingge He Date: Tue, 14 May 2024 08:46:06 -0700 Subject: [PATCH 6/9] Calculate string input --- src/infer_request.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/infer_request.cc b/src/infer_request.cc index fed087263..9b39ea831 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -1186,7 +1186,11 @@ InferenceRequest::Normalize() int64_t element_count = triton::common::GetElementCount(input_dims); expected_byte_size = 0; for (int i = 0; i < element_count; ++i) { - expected_byte_size += 4; // FIXME: Actually add the byte size + BufferAttributes* buffer_attributes; + const char* content = input.Data()->BufferAt(i, &buffer_attributes); + const uint32_t str_len = + *(reinterpret_cast(content)); + expected_byte_size += sizeof(uint32_t) + str_len; } } else { expected_byte_size = triton::common::GetByteSize(data_type, input_dims); From 6c61b2a9a0a6fc6bf88a2b9f324a9a5c8f7347cb Mon Sep 17 00:00:00 2001 From: Yingge He Date: Tue, 14 May 2024 13:48:19 -0700 Subject: [PATCH 7/9] Fix L0_infer tests --- src/infer_request.cc | 46 ++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 42 insertions(+), 4 deletions(-) diff --git a/src/infer_request.cc b/src/infer_request.cc index 9b39ea831..29219e442 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -1184,13 +1184,51 @@ InferenceRequest::Normalize() // first 4 bytes for each element to find expected byte size if (data_type == inference::DataType::TYPE_STRING) { int64_t element_count = triton::common::GetElementCount(input_dims); + int64_t element_idx = 0; expected_byte_size = 0; - for (int i = 0; i < element_count; ++i) { + for (size_t i = 0; i < input.Data()->BufferCount(); ++i) { BufferAttributes* buffer_attributes; const char* content = input.Data()->BufferAt(i, &buffer_attributes); - const uint32_t str_len = - *(reinterpret_cast(content)); - expected_byte_size += sizeof(uint32_t) + str_len; + size_t content_byte_size = input.Data()->TotalByteSize(); + + while (content_byte_size >= sizeof(uint32_t)) { + if (element_idx >= element_count) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "unexpected number of string elements " + + std::to_string(element_idx + 1) + + " for inference input '" + pr.first + "', expecting " + + std::to_string(element_count)); + } + + const uint32_t len = *(reinterpret_cast(content)); + content += sizeof(uint32_t); + content_byte_size -= sizeof(uint32_t); + expected_byte_size += sizeof(uint32_t); + + if (content_byte_size < len) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + + "incomplete string data for inference input '" + + pr.first + "', expecting string of length " + + std::to_string(len) + " but only " + + std::to_string(content_byte_size) + " bytes available"); + } + + content += len; + content_byte_size -= len; + expected_byte_size += len; + element_idx++; + } + } + + if (element_idx != element_count) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "expected " + std::to_string(element_count) + + " strings for inference input '" + pr.first + "', got " + + std::to_string(element_idx)); } } else { expected_byte_size = triton::common::GetByteSize(data_type, input_dims); From 08b11c096ad53e6bcb0bec9bf8b9a22fb2cefa20 Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 17 May 2024 00:00:38 -0700 Subject: [PATCH 8/9] Fix bug --- src/infer_request.cc | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/infer_request.cc b/src/infer_request.cc index 29219e442..2a716bc1b 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -1187,9 +1187,11 @@ InferenceRequest::Normalize() int64_t element_idx = 0; expected_byte_size = 0; for (size_t i = 0; i < input.Data()->BufferCount(); ++i) { - BufferAttributes* buffer_attributes; - const char* content = input.Data()->BufferAt(i, &buffer_attributes); - size_t content_byte_size = input.Data()->TotalByteSize(); + size_t content_byte_size; + TRITONSERVER_MemoryType content_memory_type; + int64_t content_memory_id; + const char* content = input.Data()->BufferAt( + i, &content_byte_size, &content_memory_type, &content_memory_id); while (content_byte_size >= sizeof(uint32_t)) { if (element_idx >= element_count) { From 80594106d2b0ed172c23e564830ac4c46f635e4a Mon Sep 17 00:00:00 2001 From: Yingge He Date: Fri, 17 May 2024 10:52:36 -0700 Subject: [PATCH 9/9] Move string validation into helper function. --- src/infer_request.cc | 118 +++++++++++++++++++++++-------------------- src/infer_request.h | 4 ++ 2 files changed, 68 insertions(+), 54 deletions(-) diff --git a/src/infer_request.cc b/src/infer_request.cc index 2a716bc1b..0c85051ff 100644 --- a/src/infer_request.cc +++ b/src/infer_request.cc @@ -1070,13 +1070,14 @@ InferenceRequest::Normalize() const inference::ModelInput* input_config; RETURN_IF_ERROR(model_raw_->GetInput(pr.second.Name(), &input_config)); + auto& input_id = pr.first; auto& input = pr.second; auto shape = input.MutableShape(); if (input.DType() != input_config->data_type()) { return Status( Status::Code::INVALID_ARG, - LogRequest() + "inference input '" + pr.first + "' data-type is '" + + LogRequest() + "inference input '" + input_id + "' data-type is '" + std::string( triton::common::DataTypeToProtocolString(input.DType())) + "', but model '" + ModelName() + "' expects '" + @@ -1099,7 +1100,7 @@ InferenceRequest::Normalize() Status::Code::INVALID_ARG, LogRequest() + "All input dimensions should be specified for input '" + - pr.first + "' for model '" + ModelName() + "', got " + + input_id + "' for model '" + ModelName() + "', got " + triton::common::DimsListToString(input.OriginalShape())); } else if ( (config_dims[i] != triton::common::WILDCARD_DIM) && @@ -1128,7 +1129,7 @@ InferenceRequest::Normalize() } return Status( Status::Code::INVALID_ARG, - LogRequest() + "unexpected shape for input '" + pr.first + + LogRequest() + "unexpected shape for input '" + input_id + "' for model '" + ModelName() + "'. Expected " + triton::common::DimsListToString(full_dims) + ", got " + triton::common::DimsListToString(input.OriginalShape()) + ". " + @@ -1183,55 +1184,8 @@ InferenceRequest::Normalize() // (prepend 4 bytes to specify string length), so need to add all the // first 4 bytes for each element to find expected byte size if (data_type == inference::DataType::TYPE_STRING) { - int64_t element_count = triton::common::GetElementCount(input_dims); - int64_t element_idx = 0; - expected_byte_size = 0; - for (size_t i = 0; i < input.Data()->BufferCount(); ++i) { - size_t content_byte_size; - TRITONSERVER_MemoryType content_memory_type; - int64_t content_memory_id; - const char* content = input.Data()->BufferAt( - i, &content_byte_size, &content_memory_type, &content_memory_id); - - while (content_byte_size >= sizeof(uint32_t)) { - if (element_idx >= element_count) { - return Status( - Status::Code::INVALID_ARG, - LogRequest() + "unexpected number of string elements " + - std::to_string(element_idx + 1) + - " for inference input '" + pr.first + "', expecting " + - std::to_string(element_count)); - } - - const uint32_t len = *(reinterpret_cast(content)); - content += sizeof(uint32_t); - content_byte_size -= sizeof(uint32_t); - expected_byte_size += sizeof(uint32_t); - - if (content_byte_size < len) { - return Status( - Status::Code::INVALID_ARG, - LogRequest() + - "incomplete string data for inference input '" + - pr.first + "', expecting string of length " + - std::to_string(len) + " but only " + - std::to_string(content_byte_size) + " bytes available"); - } - - content += len; - content_byte_size -= len; - expected_byte_size += len; - element_idx++; - } - } - - if (element_idx != element_count) { - return Status( - Status::Code::INVALID_ARG, - LogRequest() + "expected " + std::to_string(element_count) + - " strings for inference input '" + pr.first + "', got " + - std::to_string(element_idx)); - } + RETURN_IF_ERROR( + ValidateBytesInputs(input_id, input, &expected_byte_size)); } else { expected_byte_size = triton::common::GetByteSize(data_type, input_dims); } @@ -1239,14 +1193,13 @@ InferenceRequest::Normalize() (static_cast(byte_size) != expected_byte_size)) { return Status( Status::Code::INVALID_ARG, - LogRequest() + "input byte size mismatch for input '" + pr.first + + LogRequest() + "input byte size mismatch for input '" + input_id + "' for model '" + ModelName() + "'. Expected " + std::to_string(expected_byte_size) + ", got " + std::to_string(byte_size)); } } } - return Status::Success; } @@ -1311,6 +1264,63 @@ InferenceRequest::ValidateRequestInputs() return Status::Success; } +Status +InferenceRequest::ValidateBytesInputs( + const std::string& input_id, const Input& input, + int64_t* const expected_byte_size) const +{ + const auto& input_dims = input.ShapeWithBatchDim(); + int64_t element_count = triton::common::GetElementCount(input_dims); + int64_t element_idx = 0; + *expected_byte_size = 0; + for (size_t i = 0; i < input.Data()->BufferCount(); ++i) { + size_t content_byte_size; + TRITONSERVER_MemoryType content_memory_type; + int64_t content_memory_id; + const char* content = input.Data()->BufferAt( + i, &content_byte_size, &content_memory_type, &content_memory_id); + + while (content_byte_size >= sizeof(uint32_t)) { + if (element_idx >= element_count) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "unexpected number of string elements " + + std::to_string(element_idx + 1) + " for inference input '" + + input_id + "', expecting " + std::to_string(element_count)); + } + + const uint32_t len = *(reinterpret_cast(content)); + content += sizeof(uint32_t); + content_byte_size -= sizeof(uint32_t); + *expected_byte_size += sizeof(uint32_t); + + if (content_byte_size < len) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "incomplete string data for inference input '" + + input_id + "', expecting string of length " + + std::to_string(len) + " but only " + + std::to_string(content_byte_size) + " bytes available"); + } + + content += len; + content_byte_size -= len; + *expected_byte_size += len; + element_idx++; + } + } + + if (element_idx != element_count) { + return Status( + Status::Code::INVALID_ARG, + LogRequest() + "expected " + std::to_string(element_count) + + " strings for inference input '" + input_id + "', got " + + std::to_string(element_idx)); + } + + return Status::Success; +} + #ifdef TRITON_ENABLE_STATS void InferenceRequest::ReportStatistics( diff --git a/src/infer_request.h b/src/infer_request.h index 7ddc52c7b..c97ef8039 100644 --- a/src/infer_request.h +++ b/src/infer_request.h @@ -747,6 +747,10 @@ class InferenceRequest { // Helper for validating Inputs Status ValidateRequestInputs(); + Status ValidateBytesInputs( + const std::string& input_id, const Input& input, + int64_t* const expected_byte_size) const; + // Helpers for pending request metrics void IncrementPendingRequestCount(); void DecrementPendingRequestCount();