Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

validate input shapes #350

Merged
merged 11 commits into from
May 21, 2024
94 changes: 90 additions & 4 deletions src/infer_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include <algorithm>
#include <cstring>
#include <deque>
#include <string>

#include "constants.h"
#include "model.h"
Expand Down Expand Up @@ -1069,13 +1070,14 @@ InferenceRequest::Normalize()
const inference::ModelInput* input_config;
RETURN_IF_ERROR(model_raw_->GetInput(pr.second.Name(), &input_config));

auto& input_id = pr.first;
auto& input = pr.second;
auto shape = input.MutableShape();

if (input.DType() != input_config->data_type()) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "inference input '" + pr.first + "' data-type is '" +
LogRequest() + "inference input '" + input_id + "' data-type is '" +
std::string(
triton::common::DataTypeToProtocolString(input.DType())) +
"', but model '" + ModelName() + "' expects '" +
Expand All @@ -1098,7 +1100,7 @@ InferenceRequest::Normalize()
Status::Code::INVALID_ARG,
LogRequest() +
"All input dimensions should be specified for input '" +
pr.first + "' for model '" + ModelName() + "', got " +
input_id + "' for model '" + ModelName() + "', got " +
triton::common::DimsListToString(input.OriginalShape()));
} else if (
(config_dims[i] != triton::common::WILDCARD_DIM) &&
Expand Down Expand Up @@ -1127,7 +1129,7 @@ InferenceRequest::Normalize()
}
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "unexpected shape for input '" + pr.first +
LogRequest() + "unexpected shape for input '" + input_id +
"' for model '" + ModelName() + "'. Expected " +
triton::common::DimsListToString(full_dims) + ", got " +
triton::common::DimsListToString(input.OriginalShape()) + ". " +
Expand Down Expand Up @@ -1169,8 +1171,35 @@ InferenceRequest::Normalize()
input.MutableShapeWithBatchDim()->push_back(d);
}
}
// Matching incoming request's shape and byte size to make sure the
// payload contains correct number of elements.
// Note: Since we're using normalized input.ShapeWithBatchDim() here,
// make sure that all the normalization is before the check.
{
const size_t& byte_size = input.Data()->TotalByteSize();
const auto& data_type = input.DType();
const auto& input_dims = input.ShapeWithBatchDim();
int64_t expected_byte_size = INT_MAX;
// Triton expects STRING data in a special format: every element is
// prefixed with 4 bytes that specify its length. The expected byte size
// must therefore include the 4-byte prefix of each element.
if (data_type == inference::DataType::TYPE_STRING) {
RETURN_IF_ERROR(
ValidateBytesInputs(input_id, input, &expected_byte_size));
} else {
expected_byte_size = triton::common::GetByteSize(data_type, input_dims);
}
if ((byte_size > INT_MAX) ||
(static_cast<int64_t>(byte_size) != expected_byte_size)) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input byte size mismatch for input '" + input_id +
"' for model '" + ModelName() + "'. Expected " +
std::to_string(expected_byte_size) + ", got " +
std::to_string(byte_size));
}
}
}

return Status::Success;
}

Expand Down Expand Up @@ -1235,6 +1264,63 @@ InferenceRequest::ValidateRequestInputs()
return Status::Success;
}

// Validates the payload of a BYTES (string) input and computes its size.
//
// Each buffer of 'input' is walked as a sequence of elements, where an
// element is a 4-byte length prefix (read as a native uint32_t) followed by
// that many bytes of data.
//
// input_id: name used to identify the input in error messages.
// input: the input whose data buffers are inspected.
// expected_byte_size: reset to 0 and then increased by the size (prefix
//   plus data) of every element that is parsed. Only meaningful when
//   Status::Success is returned.
//
// Returns INVALID_ARG when more elements are found than the element count
// of input.ShapeWithBatchDim(), when an element declares a length larger
// than the bytes remaining in its buffer, or when fewer elements are found
// than expected. Returns Status::Success otherwise.
Status
InferenceRequest::ValidateBytesInputs(
    const std::string& input_id, const Input& input,
    int64_t* const expected_byte_size) const
{
  const auto& input_dims = input.ShapeWithBatchDim();
  const int64_t element_count = triton::common::GetElementCount(input_dims);
  int64_t element_idx = 0;
  *expected_byte_size = 0;

  for (size_t i = 0; i < input.Data()->BufferCount(); ++i) {
    size_t content_byte_size;
    TRITONSERVER_MemoryType content_memory_type;
    int64_t content_memory_id;
    // NOTE(review): content_memory_type is not inspected before the buffer is
    // read on the host below; presumably only CPU-accessible buffers reach
    // this point -- confirm against callers.
    const char* content = input.Data()->BufferAt(
        i, &content_byte_size, &content_memory_type, &content_memory_id);

    // Fewer than sizeof(uint32_t) trailing bytes cannot hold a length prefix;
    // they are neither parsed nor added to *expected_byte_size.
    while (content_byte_size >= sizeof(uint32_t)) {
      if (element_idx >= element_count) {
        return Status(
            Status::Code::INVALID_ARG,
            LogRequest() + "unexpected number of string elements " +
                std::to_string(element_idx + 1) + " for inference input '" +
                input_id + "', expecting " + std::to_string(element_count));
      }

      // The prefix follows string data of arbitrary length, so 'content' is
      // generally not aligned for uint32_t. Copy the bytes out instead of
      // dereferencing a reinterpret_cast'ed pointer, which would be a
      // misaligned load (undefined behavior).
      uint32_t len = 0;
      std::memcpy(&len, content, sizeof(len));
      content += sizeof(uint32_t);
      content_byte_size -= sizeof(uint32_t);
      *expected_byte_size += sizeof(uint32_t);

      // An element must be fully contained in the current buffer.
      // NOTE(review): an element split across two buffers is reported as
      // incomplete -- confirm that buffers are always element-aligned.
      if (content_byte_size < len) {
        return Status(
            Status::Code::INVALID_ARG,
            LogRequest() + "incomplete string data for inference input '" +
                input_id + "', expecting string of length " +
                std::to_string(len) + " but only " +
                std::to_string(content_byte_size) + " bytes available");
      }

      content += len;
      content_byte_size -= len;
      *expected_byte_size += len;
      element_idx++;
    }
  }

  if (element_idx != element_count) {
    return Status(
        Status::Code::INVALID_ARG,
        LogRequest() + "expected " + std::to_string(element_count) +
            " strings for inference input '" + input_id + "', got " +
            std::to_string(element_idx));
  }

  return Status::Success;
}

#ifdef TRITON_ENABLE_STATS
void
InferenceRequest::ReportStatistics(
Expand Down
4 changes: 4 additions & 0 deletions src/infer_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -747,6 +747,10 @@ class InferenceRequest {
// Helper for validating Inputs
Status ValidateRequestInputs();

// Helper for validating a BYTES (string) input. Walks the data buffers of
// 'input' as a sequence of elements, each a 4-byte length prefix followed
// by that many bytes of data, and checks the number of elements against
// the element count of the input's shape (with batch dimension). On
// success, 'expected_byte_size' holds the total size in bytes of the
// parsed elements, including their length prefixes. 'input_id' is only
// used to identify the input in error messages.
Status ValidateBytesInputs(
const std::string& input_id, const Input& input,
int64_t* const expected_byte_size) const;

// Helpers for pending request metrics
void IncrementPendingRequestCount();
void DecrementPendingRequestCount();
Expand Down
Loading