Skip to content

Commit

Permalink
Update shared memory bound checking for infer requests (#565)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbkyang-nvi committed Apr 12, 2024
1 parent 22796ae commit 90d3414
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,9 @@ SharedMemoryManager::RegisterSystemMemory(

Error
SharedMemoryManager::GetMemoryInfo(
const std::string& name, size_t offset, void** shm_mapped_addr,
TRITONSERVER_MemoryType* memory_type, int64_t* device_id)
const std::string& name, size_t offset, size_t byte_size,
void** shm_mapped_addr, TRITONSERVER_MemoryType* memory_type,
int64_t* device_id)
{
// protect shared_memory_map_ from concurrent access
std::lock_guard<std::mutex> lock(mu_);
Expand All @@ -100,6 +101,29 @@ SharedMemoryManager::GetMemoryInfo(
return Error(
std::string("Unable to find shared memory region: '" + name + "'"));
}

// validate offset
size_t shm_region_end = 0;
if (it->second->kind_ == TRITONSERVER_MEMORY_CPU) {
shm_region_end = it->second->offset_;
}
if (it->second->byte_size_ > 0) {
shm_region_end += it->second->byte_size_ - 1;
}
if (offset > shm_region_end) {
return Error(
std::string("Invalid offset for shared memory region: '" + name + "'")
.c_str());
}
// validate byte_size + offset is within memory bounds
size_t total_req_shm = offset + byte_size - 1;
if (total_req_shm > shm_region_end) {
return Error(std::string(
"Invalid offset + byte size for shared memory region: '" +
name + "'")
.c_str());
}

if (it->second->kind_ == TRITONSERVER_MEMORY_CPU) {
*shm_mapped_addr = (void*)((uint8_t*)it->second->mapped_addr_ +
it->second->offset_ + offset);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,19 +70,22 @@ class SharedMemoryManager {
Error RegisterSystemMemory(
const std::string& name, void* ptr, const size_t byte_size);

/// Get the access information for the shared memory block with the specified
/// name. Return an Error if named block doesn't exist.
/// Get the access information for the shared memory block
/// with the specified name. Return an Error
/// if named block doesn't exist.
/// \param name The name of the shared memory block to get.
/// \param offset The offset in the block
/// \param byte_size The byte size to request for the shm region
/// \param shm_mapped_addr Returns the pointer to the shared
/// memory block with the specified name and offset
/// \param memory_type Returns the type of the memory
/// \param device_id Returns the device id associated with the
/// memory block
/// \return an Error indicating success or failure.
Error GetMemoryInfo(
const std::string& name, size_t offset, void** shm_mapped_addr,
TRITONSERVER_MemoryType* memory_type, int64_t* device_id);
const std::string& name, size_t offset, size_t byte_size,
void** shm_mapped_addr, TRITONSERVER_MemoryType* memory_type,
int64_t* device_id);

/// Removes the named shared memory block of the specified type from
/// the manager. Any future attempt to get the details of this block
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -945,7 +945,8 @@ TritonLoader::Infer(
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;
RETURN_IF_ERROR(shm_manager_->GetMemoryInfo(
shm_name, offset, &buf, &memory_type, &memory_type_id));
shm_name, offset, shm_byte_size, &buf, &memory_type,
&memory_type_id));

alloc_payload.output_map_.emplace(
std::piecewise_construct, std::forward_as_tuple(output->Name()),
Expand Down Expand Up @@ -1149,7 +1150,8 @@ TritonLoader::AddInputs(
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;
RETURN_IF_ERROR(shm_manager_->GetMemoryInfo(
shm_name, offset, &buf, &memory_type, &memory_type_id));
shm_name, offset, shm_byte_size, &buf, &memory_type,
&memory_type_id));
RETURN_IF_TRITONSERVER_ERROR(
inference_request_append_input_data_fn_(
irequest, input_name, buf, byte_size,
Expand Down

0 comments on commit 90d3414

Please sign in to comment.