Skip to content

Commit

Permalink
fix: Avoid resolving non-tensor inputs to torch segment_blocks unnec…
Browse files Browse the repository at this point in the history
…essarily

Signed-off-by: Michael Feliz <[email protected]>
  • Loading branch information
mfeliz-cruise committed May 3, 2022
1 parent 10b55d4 commit 3e090ee
Show file tree
Hide file tree
Showing 2 changed files with 143 additions and 29 deletions.
62 changes: 33 additions & 29 deletions core/partitioning/partitioning.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -137,17 +137,8 @@ std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock
return std::pair<std::unordered_map<torch::jit::Value*, SegmentedBlock>, SegmentedBlock>(append_blocks, trt_block);
}

PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
// reconstruct segmented_block if this block requires nonTensor input
std::vector<torch::jit::Value*> nontensor_inputs;
// Gather all non-tensor inputs for this seg_block
for (auto input : seg_block.raw_inputs()) {
if (!isTensorOrTensorList(input)) {
nontensor_inputs.push_back(input);
}
}

std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(nontensor_inputs);
PartitionedGraph segmentBlocksWithSpecifiedInputs(SegmentedBlock& seg_block, std::vector<torch::jit::Value*> inputs_to_resolve){
std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(inputs_to_resolve);
PartitionedGraph new_seg_blocks;
// if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, merge the
// dependency nodes at the beginning of the current segmented_block and return this merged segmented_block
Expand All @@ -162,15 +153,15 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
}
} else {
// if current block is kTensorRT but the dependency nodes contain unsupported node, then we have to segment again
std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
std::unordered_set<torch::jit::Value*> inputs_to_resolve_set(inputs_to_resolve.begin(), inputs_to_resolve.end());
std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes(dependency_nodes.begin(), dependency_nodes.end());

bool prev_non_tensor_outputs = false;
for (auto n : seg_block.raw_nodes()) {
// Check if the node has non-tensor inputs or if it consumes non-tensor outputs of previous node.
// In these cases, these nodes are placed into a new Pytorch SegmentedBlock. Else, they form a new TensorRT
// SegmentedBlock.
if (containTargetInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
if (containTargetInputs(n, inputs_to_resolve_set) || prev_non_tensor_outputs) {
// If tensorrt_nodes is not empty, the previous nodes were all tensorrt_nodes. Construct a
// TensorRT segmented_block and clear the tensorrt_nodes list to be later used for new TRT segments.
if (!tensorrt_nodes.empty()) {
Expand Down Expand Up @@ -201,6 +192,18 @@ PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
return new_seg_blocks;
}

PartitionedGraph segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
  // Re-segment this block when it takes non-tensor inputs: collect every
  // input that is neither a Tensor nor a TensorList and hand the block off
  // to segmentBlocksWithSpecifiedInputs to resolve exactly those values.
  std::vector<torch::jit::Value*> non_tensor_inputs;
  for (auto block_input : seg_block.raw_inputs()) {
    if (isTensorOrTensorList(block_input)) {
      continue; // tensor-typed inputs are supported directly; nothing to resolve
    }
    non_tensor_inputs.push_back(block_input);
  }
  return segmentBlocksWithSpecifiedInputs(seg_block, non_tensor_inputs);
}

std::unordered_map<torch::jit::Value*, usage_info> getInputUsageCounts(
const PartitionedGraph& segmented_blocks,
const std::function<bool(torch::jit::Value*)>& condition) {
Expand Down Expand Up @@ -248,6 +251,9 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
segmented_blocks, [](torch::jit::Value* input) -> bool { return !isTensorOrTensorList(input); });
auto idx_to_iter = getIdxtoIterMap(segmented_blocks_list);

std::map<int, std::vector<torch::jit::Value*>> torch_values_to_fix; //Only need to resolve values generated by tensorrt
std::set<int> tensorrt_blocks_to_fix; //Need to resolve ALL non-tensor inputs

// update blocks_list
std::unordered_set<int> updated_segments;
for (auto& use : usage_counts) {
Expand All @@ -256,27 +262,25 @@ void resolveNonTensorInputBlocks(PartitionedGraph& segmented_blocks) {
// kTorch segment.
if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
auto first_torch_id = use_info.torch_use_id.back();
if (!updated_segments.count(first_torch_id)) {
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
// Torch-TensorRT doesn't support non-tensor inputs for a module.
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]);
auto next_iter = segmented_blocks_list.erase(idx_to_iter[first_torch_id]);
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
updated_segments.insert(first_torch_id);
}
torch_values_to_fix[first_torch_id].push_back(use.first);
}
// kTensorRT segments always need to inject nodes for the nonTensor inputs
for (auto i : use_info.tensorrt_use_id) {
if (!updated_segments.count(i)) {
// Segmented Blocks with non-tensor inputs will have to be re-segmented as
// Torch-TensorRT doesn't support non-tensor inputs for a module.
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]);
auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
updated_segments.insert(i);
}
tensorrt_blocks_to_fix.insert(i);
}
}
for(auto torch_block_pair : torch_values_to_fix){
auto to_inject_blocks = segmentBlocksWithSpecifiedInputs(segmented_blocks[torch_block_pair.first], torch_block_pair.second);
auto next_iter = segmented_blocks_list.erase(idx_to_iter[torch_block_pair.first]);
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
}

for(auto i : tensorrt_blocks_to_fix){
auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]);
auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
}

segmented_blocks.clear();
segmented_blocks.insert(segmented_blocks.begin(), segmented_blocks_list.begin(), segmented_blocks_list.end());
return;
Expand Down
110 changes: 110 additions & 0 deletions tests/core/partitioning/test_resolve_nontensor_inputs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,113 @@ TEST(Partitioning, ConvertForTensorListInputsInFallbackCorrectly) {
int count = count_trt_engines(fallback_g);
ASSERT_TRUE(count == 2);
}

// Verifies that partitioning resolves only the non-tensor values a Torch
// segment actually needs, instead of all non-tensor inputs of the block.
// The graph routes tensors through a Dict(str, Tensor) (unsupported by TRT),
// so the dict ops must fall back to Torch while the tensor math stays in TRT.
TEST(Partitioning, ResolveOnlyNeccessaryNonTensorInputs) {
  /* parseIR does not support "= aten::_set_item" so we will build this graph manually
const auto graph = R"IR(
graph(%x : Tensor,
%y : Tensor):
%2 : str = prim::Constant[value="INS"]()
%3 : str = prim::Constant[value="OUTS"]()
%4 : bool = prim::Constant[value=0]()
%5 : int = prim::Constant[value=-1]()
%6 : Dict(str, Tensor) = prim::DictConstruct()
= aten::_set_item(%6, %2, %x)
%7 : Tensor = aten::__getitem__(%6, %2)
%8 : Tensor = aten::lt(%7, %y)
%9 : Tensor?[] = prim::ListConstruct(%8)
%10 : int = prim::dtype(%7)
%11 : Device = prim::device(%7)
%12 : Tensor = aten::tensor(%5, %10, %11, %4)
%13 : Tensor = aten::index_put_(%7, %9, %12, %4)
= aten::_set_item(%6, %3, %7)
%14 : Tensor = aten::__getitem__(%6, %2)
%15 : Tensor = aten::__getitem__(%6, %3)
return (%14, %15))IR";
*/
// Build the graph above by hand, node by node.
auto g = std::make_shared<torch::jit::Graph>();
auto x = g->insertInput(0, "x");
auto y = g->insertInput(1, "y");
// String constants used as dict keys.
torch::jit::IValue ins_key("INS");
auto ins_key_val = g->insertConstant(ins_key);
torch::jit::IValue outs_key("OUTS");
auto outs_key_val = g->insertConstant(outs_key);
// A `false` constant: insert int 0 and retype it to bool (workaround for
// building the constant manually rather than through parseIR).
torch::jit::IValue zero(0);
auto false_const_val = g->insertConstant(zero);
false_const_val->setType(c10::BoolType::get());
torch::jit::IValue neg_one(-1);
auto neg_one_const_val = g->insertConstant(neg_one);
// %6: empty Dict(str, Tensor) — the non-tensor value the test revolves around.
auto dict_node = g->createDict(ins_key_val->type(), x->type(), torch::jit::ArrayRef<torch::jit::Value*>(), torch::jit::ArrayRef<torch::jit::Value*>());
g->insertNode(dict_node);
// dict["INS"] = x  (aten::_set_item has no outputs, hence n_outputs = 0)
auto set_node = g->create(torch::jit::Symbol::fromQualString("aten::_set_item"), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val, x}, 0);
g->insertNode(set_node);
// %7 = dict["INS"]
auto get_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val}, 1);
g->insertNode(get_node);
// %8 = %7 < y — tensor math that a TRT segment can take.
auto lt_node = g->create(torch::jit::Symbol::fromQualString("aten::lt"), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output(), y}, 1);
g->insertNode(lt_node);
// %9 = [ %8 ] as Tensor?[] — index list for index_put_.
auto list_node = g->createList(at::OptionalType::create(lt_node->output()->type()), torch::jit::ArrayRef<torch::jit::Value*>{lt_node->output()});
g->insertNode(list_node);
// %10 / %11: dtype and device of %7, fed into aten::tensor below.
auto dtype_node = g->create(torch::jit::Symbol::fromQualString("prim::dtype"), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()}, 1);
dtype_node->output()->setType(neg_one_const_val->type());
g->insertNode(dtype_node);
auto device_node = g->create(torch::jit::Symbol::fromQualString("prim::device"), torch::jit::ArrayRef<torch::jit::Value*>{get_node->output()}, 1);
device_node->output()->setType(c10::DeviceObjType::get());
g->insertNode(device_node);
// %12 = tensor(-1, dtype, device, requires_grad=false)
auto tensor_node = g->create(torch::jit::Symbol::fromQualString("aten::tensor"), torch::jit::ArrayRef<torch::jit::Value*>{neg_one_const_val, dtype_node->output(), device_node->output(), false_const_val}, 1);
g->insertNode(tensor_node);
// %13 = index_put_(%7, %9, %12, accumulate=false)
auto index_put_node = g->create(torch::jit::Symbol::fromQualString("aten::index_put_"),
torch::jit::ArrayRef<torch::jit::Value*>{get_node->output(), list_node->output(), tensor_node->output(), false_const_val}, 1);
g->insertNode(index_put_node);
// dict["OUTS"] = %7
auto out_set_node = g->create(torch::jit::Symbol::fromQualString("aten::_set_item"),
torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val, get_node->output()}, 0);
g->insertNode(out_set_node);
// Graph outputs: %14 = dict["INS"], %15 = dict["OUTS"].
auto get_ins_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), ins_key_val}, 1);
g->insertNode(get_ins_node);
auto get_outs_node = g->create(torch::jit::Symbol::fromQualString("aten::__getitem__"), torch::jit::ArrayRef<torch::jit::Value*>{dict_node->output(), outs_key_val}, 1);
g->insertNode(get_outs_node);
g->registerOutput(get_ins_node->output());
g->registerOutput(get_outs_node->output());

// Partition with fallback enabled and two 4x4 float inputs.
torch_tensorrt::core::partitioning::PartitionInfo partition_info;
partition_info.enabled = true;
std::vector<torch_tensorrt::core::ir::Input> inputs;
inputs.push_back(torch_tensorrt::core::ir::Input({4, 4}));
inputs.push_back(torch_tensorrt::core::ir::Input({4, 4}));

std::unordered_map<const torch::jit::Value*, torch_tensorrt::core::ir::Input> inputs_map;
std::unordered_map<const torch::jit::Value*, c10::optional<at::ScalarType>> input_types;
for (size_t i = 0; i < g->inputs().size(); ++i) {
inputs_map.insert({g->inputs()[i], inputs[i]});
input_types.insert({g->inputs()[i], {at::kFloat}});
}
auto input_ivalues_map = torch_tensorrt::core::partitioning::generateRandomInputs(inputs_map, input_types);
auto segmented_blocks =
torch_tensorrt::core::partitioning::Partition(g->block(), input_ivalues_map, partition_info);

// Every TRT block must only take Tensor inputs; each Torch block should
// touch the dict on exactly one side (either consumes it or produces it).
int torch_block_cnt = 0, trt_block_cnt = 0;
for (const auto& segmented_block : segmented_blocks) {
if (segmented_block.target() == torch_tensorrt::core::partitioning::SegmentedBlock::kTensorRT) {
++trt_block_cnt;
ASSERT_TRUE(checkSegmentedBlockInputType(segmented_block, [](torch::jit::TypePtr type_ptr) {
return type_ptr->isSubtypeOf(torch::jit::TensorType::get());
}));
} else {
++torch_block_cnt;
bool output_dict = false;
bool input_dict = false;
auto dict_type = dict_node->output()->type();
for (auto in : segmented_block.raw_inputs()) {
if(in->type()->isSubtypeOf(dict_type)){
input_dict = true;
}
}
for (auto out : segmented_block.raw_outputs()) {
if(out->type()->isSubtypeOf(dict_type)){
output_dict = true;
}
}
// XOR: exactly one of {takes dict as input, produces dict as output}.
EXPECT_TRUE(output_dict ^ input_dict);
}
}
// Expected partition: one TensorRT block and two Torch blocks.
ASSERT_TRUE(trt_block_cnt == 1 && torch_block_cnt == 2);
}

0 comments on commit 3e090ee

Please sign in to comment.