Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

validate input shapes #350

Merged
merged 11 commits into from
May 21, 2024
74 changes: 74 additions & 0 deletions src/infer_request.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include <algorithm>
#include <deque>
#include <string>

#include "constants.h"
#include "model.h"
Expand Down Expand Up @@ -1169,6 +1170,79 @@ InferenceRequest::Normalize()
input.MutableShapeWithBatchDim()->push_back(d);
}
}
// Matching incoming request's shape and byte size to make sure the
// payload contains correct number of elements.
// Note: Since we're using normalized input.ShapeWithBatchDim() here,
// make sure that all the normalization is before the check.
{
const size_t& byte_size = input.Data()->TotalByteSize();
const auto& data_type = input.DType();
const auto& input_dims = input.ShapeWithBatchDim();
int64_t expected_byte_size = INT_MAX;
// Because Triton expects STRING type to be in special format
// (prepend 4 bytes to specify string length), so need to add all the
// first 4 bytes for each element to find expected byte size
if (data_type == inference::DataType::TYPE_STRING) {
int64_t element_count = triton::common::GetElementCount(input_dims);
jbkyang-nvi marked this conversation as resolved.
Show resolved Hide resolved
int64_t element_idx = 0;
expected_byte_size = 0;
for (size_t i = 0; i < input.Data()->BufferCount(); ++i) {
BufferAttributes* buffer_attributes;
const char* content = input.Data()->BufferAt(i, &buffer_attributes);
size_t content_byte_size = input.Data()->TotalByteSize();

while (content_byte_size >= sizeof(uint32_t)) {
jbkyang-nvi marked this conversation as resolved.
Show resolved Hide resolved
if (element_idx >= element_count) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "unexpected number of string elements " +
std::to_string(element_idx + 1) +
" for inference input '" + pr.first + "', expecting " +
std::to_string(element_count));
}

const uint32_t len = *(reinterpret_cast<const uint32_t*>(content));
content += sizeof(uint32_t);
content_byte_size -= sizeof(uint32_t);
expected_byte_size += sizeof(uint32_t);

if (content_byte_size < len) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() +
"incomplete string data for inference input '" +
pr.first + "', expecting string of length " +
std::to_string(len) + " but only " +
std::to_string(content_byte_size) + " bytes available");
}

content += len;
content_byte_size -= len;
expected_byte_size += len;
element_idx++;
}
}

if (element_idx != element_count) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected " + std::to_string(element_count) +
" strings for inference input '" + pr.first + "', got " +
std::to_string(element_idx));
}
} else {
expected_byte_size = triton::common::GetByteSize(data_type, input_dims);
}
if ((byte_size > INT_MAX) ||
(static_cast<int64_t>(byte_size) != expected_byte_size)) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input byte size mismatch for input '" + pr.first +
"' for model '" + ModelName() + "'. Expected " +
std::to_string(expected_byte_size) + ", got " +
std::to_string(byte_size));
}
}
}

return Status::Success;
Expand Down
Loading