Post training quantization support in TRTorch #44

Merged 21 commits on Apr 25, 2020
Commits
dd443a6 feat(//core/quantization): skeleton of INT8 PTQ calibrator (narendasan, Apr 3, 2020)
8580106 Merge branch 'master' of https://github.com/NVIDIA/trtorch into ptq (narendasan, Apr 8, 2020)
676bf56 feat(//cpp/ptq/training): Training recipe for VGG16 Classifier on (narendasan, Apr 16, 2020)
f022dfe feat(//cpp/api): Functional Dataloader based PTQ (narendasan, Apr 22, 2020)
2f86f84 feat(//cpp/api): Remove the extra includes in the API header (narendasan, Apr 22, 2020)
2dd1ba3 feat(//core/execution): Type checking for the executor, now is the (narendasan, Apr 23, 2020)
0050f0e bug(//tests): Test to reproduce FP16 accuracy issue (narendasan, Apr 23, 2020)
fc70267 fix(//cpp/api): Remove unnecessary destructor in ptq class (narendasan, Apr 23, 2020)
1b25542 feat(//cpp/api): Adding max batch size setting (narendasan, Apr 23, 2020)
4a8dc6e test(//tests/modules): Add FP16 test to testsuite (narendasan, Apr 23, 2020)
5f36f47 feat(//cpp/ptq): Add a feature to the dataset to use less than the full (narendasan, Apr 23, 2020)
5c0d737 feat(/cpp/api): Working INT8 Calibrator, also resolves #41 (narendasan, Apr 24, 2020)
8d22bdd fix(//cpp/api): Better initial condition for the dataloader iterator to (narendasan, Apr 24, 2020)
3afd209 fix(//core/conversion): Check for calibrator before setting int8 mode (narendasan, Apr 24, 2020)
825be69 fix(//cpp/api): set a default for calibrator (narendasan, Apr 24, 2020)
b989c7f fix(//cpp/ptq): remove some logging from ptq app (narendasan, Apr 24, 2020)
df74136 feat(//tests): New optional accuracy tests to check INT8 and FP16 (narendasan, Apr 24, 2020)
6facbf1 docs(//cpp/ptq): READMEs and documentation for running the PTQ example (narendasan, Apr 24, 2020)
46bb485 docs(//cpp/ptq): Last comment cleanup (narendasan, Apr 25, 2020)
cd24f26 fix: Address issues in PR (narendasan, Apr 25, 2020)
54a24b3 fix(//cpp/ptq): Tracing model in eval mode wrecks accuracy in Libtorch (narendasan, Apr 25, 2020)
12 changes: 11 additions & 1 deletion .gitignore
@@ -13,4 +13,14 @@ experiments/
py/build/
py/tmp/
py/.eggs
.vscode/
.DS_Store
._DS_Store
*.pth
*.pyc
cpp/ptq/training/vgg16/data/*
*.bin
cpp/ptq/datasets/data/
tests/accuracy/datasets/data/*
._.DS_Store
*.tar.gz
7 changes: 6 additions & 1 deletion README.md
@@ -17,14 +17,19 @@ More Information / System Architecture:
...
auto compile_settings = trtorch::ExtraInfo(dims);
// FP16 execution
compile_settings.op_precision = torch::kHalf;
compile_settings.op_precision = torch::kFloat;
// Compile module
auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);
// Run like normal
auto results = trt_mod.forward({in_tensor});
...
```

> Notes on running in lower precisions (sketched below):
> - Set the precision with extra_info.op_precision
> - The module should be left in FP32 before compilation (FP16 mode can also handle models that already contain half tensors)
> - In FP16 mode, only the input tensors should be converted to FP16; the other precisions use FP32 inputs
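
For instance, a minimal FP16 sketch following these notes (reusing `dims`, `ts_mod`, and `in_tensor` from the example above):

```c++
// Leave the module itself in FP32 and select FP16 only as the engine's
// operating precision.
auto compile_settings = trtorch::ExtraInfo(dims);
compile_settings.op_precision = torch::kHalf;
auto trt_mod = trtorch::CompileGraph(ts_mod, compile_settings);

// In FP16 mode, convert just the input tensors to half precision.
auto results = trt_mod.forward({in_tensor.to(torch::kHalf)});
```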

## Platform Support

| Platform | Support |
2 changes: 1 addition & 1 deletion core/BUILD
@@ -16,7 +16,7 @@ cc_library(
"@libtorch//:libtorch",
"@tensorrt//:nvinfer"
],
alwayslink=True,
)


40 changes: 21 additions & 19 deletions core/compiler.cpp
@@ -24,24 +24,24 @@
namespace trtorch {
namespace core {

c10::FunctionSchema GenerateGraphSchema(torch::jit::script::Module mod, std::string method_name, std::shared_ptr<torch::jit::Graph>& g) {

std::vector<c10::Argument> args;
for (auto in : g->inputs()) {
args.push_back(c10::Argument(in->debugName(), in->type()));
}

std::vector<c10::Argument> returns;
for (auto out : g->outputs()) {
returns.push_back(c10::Argument(out->debugName(), out->type()));
}

return c10::FunctionSchema(method_name, method_name, args, returns);
}


void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit::Graph>& g, std::string& serialized_engine) {
execution::EngineID uid = execution::RegisterEngineFromSerializedEngine(serialized_engine);
auto schema = execution::GetEngineFunctionSchema(uid);
auto num_io = execution::GetEngineIO(uid);

@@ -53,14 +53,14 @@ void AddEngineToGraph(torch::jit::script::Module mod, std::shared_ptr<torch::jit
in_val->setType(c10::TensorType::get());
graph_inputs.push_back(in_val);
}

auto engine_node = g->create(c10::Symbol::fromQualString(schema.name()), torch::jit::ArrayRef<torch::jit::Value*>(graph_inputs), num_io.second);
g->block()->appendNode(engine_node);

for (auto o : engine_node->outputs()) {
g->registerOutput(o);
}

return;
}

@@ -69,48 +69,50 @@ bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod,
auto g = mod.get_method(method_name).graph();
// Go through PyTorch Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());

g = graph_and_parameters.first;

// Go through TRTorch Lowering to reformat graph to be conversion friendly
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
lowering::LowerGraph(g);

auto params = graph_and_parameters.second;
auto named_params = conversion::get_named_params(g->inputs(), params);
LOG_DEBUG(*g << "(CheckMethodOperatorSupport)\n");

// Is this necessary?
lowering::LowerBlock(g->block());

return conversion::VerifyConverterSupportForBlock(g->block());
}

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
std::string method_name,
conversion::ExtraInfo cfg) {
ExtraInfo cfg) {
auto convert_cfg = std::move(cfg.convert_info);

auto g = mod.get_method(method_name).graph();
// Go through PyTorch Lowering to simplify graph and extract weight parameters
auto graph_and_parameters = torch::jit::LowerGraph(*g, mod._ivalue());

g = graph_and_parameters.first;

// Go through TRTorch Lowering to reformat graph to be conversion friendly
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
lowering::LowerGraph(g);

auto params = graph_and_parameters.second;
auto named_params = conversion::get_named_params(g->inputs(), params);
LOG_INFO(*g << "(CompileGraph)\n");

// Is this necessary?
lowering::LowerBlock(g->block());
auto engine = ConvertBlockToEngine(g->block(), cfg, named_params);
auto engine = ConvertBlockToEngine(g->block(), convert_cfg, named_params);
return std::move(engine);
}

torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,
conversion::ExtraInfo cfg) {
ExtraInfo cfg) {
// TODO: Should be doing a functional transform but need PR #31978
// [jit] More robust mangling
// torch::jit::script::Module new_mod = mod.clone();
@@ -128,7 +130,7 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod,

return new_mod;
}

} // namespace core
} // namespace trtorch

13 changes: 10 additions & 3 deletions core/compiler.h
@@ -6,12 +6,19 @@

namespace trtorch {
namespace core {

struct ExtraInfo {
ExtraInfo(std::vector<conversion::InputRange> input_ranges)
: convert_info(std::move(input_ranges)) {}
conversion::ConversionInfo convert_info;
};

bool CheckMethodOperatorSupport(const torch::jit::script::Module& mod, std::string method_name);

std::string ConvertGraphToTRTEngine(const torch::jit::script::Module& mod,
std::string method_name, conversion::ExtraInfo cfg);
std::string method_name, ExtraInfo cfg);

torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, conversion::ExtraInfo cfg);
torch::jit::script::Module CompileGraph(const torch::jit::script::Module& module, ExtraInfo cfg);

} // namespace core
} // namespace trtorch
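
A quick sketch of how a caller drives the renamed core-level config (the single-shape `InputRange` constructor and the method name are assumptions for illustration):

```c++
#include "core/compiler.h"

namespace core = trtorch::core;

// Engine settings (precision, calibrator, max batch size, ...) now live on
// cfg.convert_info.engine_settings rather than on a conversion::ExtraInfo.
std::string BuildEngine(const torch::jit::script::Module& mod) {
  core::ExtraInfo cfg({core::conversion::InputRange({1, 3, 224, 224})});
  cfg.convert_info.engine_settings.op_precision = nvinfer1::DataType::kHALF;
  return core::ConvertGraphToTRTEngine(mod, "forward", cfg);
}
```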
8 changes: 4 additions & 4 deletions core/conversion/conversion.cpp
@@ -133,7 +133,7 @@ void AddInputs(ConversionCtx* ctx,
"Expected dimension specifications for all input tensors" \
<< ", but found " << input_tensors.size() \
<< " input tensors and " \
<< input_dims.size() << "dimension specs (conversion.AddInputs)");
<< input_dims.size() << " dimension specs (conversion.AddInputs)");

auto profile = ctx->builder->createOptimizationProfile();

@@ -179,7 +179,7 @@ void AddParamsToCtxValueMap(ConversionCtx* ctx, GraphParams& params) {
}
}

void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
LOG_INFO(ctx->logger, "Converting Block");

auto inputs = b->inputs();
@@ -221,7 +221,7 @@ void ConvertBlockToNetDef(ConversionCtx* ctx, const torch::jit::Block* b, ExtraI
// a serialized TensorRT engine that can be deserialized and run

// Probably should consolidate these two functions
std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params) {
std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params) {
ConversionCtx ctx(build_info.engine_settings);
ConvertBlockToNetDef(&ctx, b, build_info, static_params);
std::string engine = ctx.SerializeEngine();
@@ -235,7 +235,7 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b) {
if (!OpSupported(n)) {
auto schema = n->maybeSchema();
TRTORCH_CHECK(schema, "Unable to get schema for Node " << util::node_info(n) \
<< " (conversion.AddLayer)");
<< " (conversion.VerifyCoverterSupportForBlock");
std::stringstream ss;
ss << *schema;
unsupported_ops.insert(ss.str());
6 changes: 3 additions & 3 deletions core/conversion/conversion.h
@@ -30,10 +30,10 @@ struct InputRange {
std::vector<int64_t> max_shape);
};

struct ExtraInfo {
struct ConversionInfo {
std::vector<InputRange> input_ranges;
BuilderSettings engine_settings;
ExtraInfo(std::vector<InputRange> input_ranges)
ConversionInfo(std::vector<InputRange> input_ranges)
: input_ranges(std::move(input_ranges)), engine_settings(BuilderSettings()) {}
};

@@ -43,7 +43,7 @@ GraphParams get_named_params(c10::ArrayRef<torch::jit::Value*> inputs, std::vect

// Converts a already lowered block (blocks with no sub blocks) to
// a serialized TensorRT engine that can be deserialized and run
std::string ConvertBlockToEngine(const torch::jit::Block* b, ExtraInfo build_info, GraphParams& static_params);
std::string ConvertBlockToEngine(const torch::jit::Block* b, ConversionInfo build_info, GraphParams& static_params);

bool OpSupported(const torch::jit::Node* n);

48 changes: 32 additions & 16 deletions core/conversion/conversionctx/ConversionCtx.cpp
@@ -9,17 +9,25 @@ namespace core {
namespace conversion {

std::ostream& operator<<(std::ostream& os, const BuilderSettings& s) {
os << "Settings requested for TensorRT engine:" \
<< "\n Operating Precision: " << s.op_precision \
<< "\n Make Refittable Engine: " << s.refit \
<< "\n Debuggable Engine: " << s.debug \
<< "\n Strict Type: " << s.strict_type \
<< "\n Allow GPU Fallback (if running on DLA): " << s.allow_gpu_fallback \
<< "\n Min Timing Iterations: " << s.num_min_timing_iters \
<< "\n Avg Timing Iterations: " << s.num_avg_timing_iters \
<< "\n Max Workspace Size: " << s.workspace_size \
<< "\n Device Type: " << s.device \
<< "\n Engine Capability: " << s.capability;
os << "Settings requested for TensorRT engine:" \
<< "\n Operating Precision: " << s.op_precision \
<< "\n Make Refittable Engine: " << s.refit \
<< "\n Debuggable Engine: " << s.debug \
<< "\n Strict Type: " << s.strict_types \
<< "\n Allow GPU Fallback (if running on DLA): " << s.allow_gpu_fallback \
<< "\n Min Timing Iterations: " << s.num_min_timing_iters \
<< "\n Avg Timing Iterations: " << s.num_avg_timing_iters \
<< "\n Max Workspace Size: " << s.workspace_size;

if (s.max_batch_size != 0) {
os << "\n Max Batch Size: " << s.max_batch_size;
} else {
os << "\n Max Batch Size: Not set";
}

os << "\n Device Type: " << s.device \
<< "\n Engine Capability: " << s.capability \
<< "\n Calibrator Created: " << (s.calibrator != nullptr);
return os;
}

@@ -36,13 +44,17 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)

switch(settings.op_precision) {
case nvinfer1::DataType::kHALF:
TRTORCH_CHECK(builder->platformHasFastFp16(), "Requested inference in FP16 but platform does not support FP16");
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
input_type = nvinfer1::DataType::kHALF;
break;
// case nvinfer1::DataType::kINT8:
// cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
// input_type = nvinfer1::DataType::kFLOAT;
// break;
case nvinfer1::DataType::kINT8:
TRTORCH_CHECK(builder->platformHasFastInt8(), "Requested inference in INT8 but platform does not support INT8");
cfg->setFlag(nvinfer1::BuilderFlag::kINT8);
input_type = nvinfer1::DataType::kFLOAT;
TRTORCH_CHECK(settings.calibrator != nullptr, "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the ExtraInfo struct with your calibrator");
cfg->setInt8Calibrator(settings.calibrator);
break;
case nvinfer1::DataType::kFLOAT:
default:
input_type = nvinfer1::DataType::kFLOAT;
@@ -57,14 +69,18 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
cfg->setFlag(nvinfer1::BuilderFlag::kDEBUG);
}

if (settings.strict_type) {
if (settings.strict_types) {
cfg->setFlag(nvinfer1::BuilderFlag::kSTRICT_TYPES);
}

if (settings.allow_gpu_fallback) {
cfg->setFlag(nvinfer1::BuilderFlag::kGPU_FALLBACK);
}

if (settings.max_batch_size != 0) {
builder->setMaxBatchSize(settings.max_batch_size);
}

cfg->setMinTimingIterations(settings.num_min_timing_iters);
cfg->setAvgTimingIterations(settings.num_avg_timing_iters);
cfg->setMaxWorkspaceSize(settings.workspace_size);
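
Together with the new `BuilderSettings` fields shown in the next file, INT8 mode now demands a calibrator. A minimal sketch of wiring these core-level settings up (the caller-supplied calibrator is assumed; any `nvinfer1::IInt8Calibrator` implementation, such as the PTQ calibrator this PR adds under `//cpp/api`, would do):

```c++
#include "core/conversion/conversionctx/ConversionCtx.h"

namespace cc = trtorch::core::conversion;

// Configure INT8 conversion: the precision is checked against
// platformHasFastInt8(), and the calibrator must be non-null or the
// TRTORCH_CHECK above fires at build time.
cc::BuilderSettings MakeInt8Settings(nvinfer1::IInt8Calibrator* calibrator) {
  cc::BuilderSettings settings;
  settings.op_precision = nvinfer1::DataType::kINT8;
  settings.calibrator = calibrator;
  settings.max_batch_size = 32; // 0 (the default) leaves TensorRT's setting untouched
  return settings;
}
```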
4 changes: 3 additions & 1 deletion core/conversion/conversionctx/ConversionCtx.h
@@ -20,13 +20,15 @@ struct BuilderSettings {
nvinfer1::DataType op_precision = nvinfer1::DataType::kFLOAT;
bool refit = false;
bool debug = false;
bool strict_type = false;
bool strict_types = false;
bool allow_gpu_fallback = true;
nvinfer1::DeviceType device = nvinfer1::DeviceType::kGPU;
nvinfer1::EngineCapability capability = nvinfer1::EngineCapability::kDEFAULT;
nvinfer1::IInt8Calibrator* calibrator = nullptr;
uint64_t num_min_timing_iters = 2;
uint64_t num_avg_timing_iters = 1;
uint64_t workspace_size = 0;
uint64_t max_batch_size = 0;

BuilderSettings() = default;
BuilderSettings(const BuilderSettings& other) = default;
3 changes: 2 additions & 1 deletion core/conversion/converters/impl/batch_norm.cpp
@@ -83,7 +83,8 @@ volatile auto batch_norm_registrations = RegisterNodeConversionPatterns()
auto gamma = args[1].unwrapToTensor();

if (/*training*/ args[5].unwrapToBool()) {
LOG_WARNING("TensorRT only converts forward pass of graphs, but saw training = True, may see undefined behavior, consider placing module in eval mode");
LOG_WARNING(R"WARN(TRTorch only converts forward pass of graphs, but saw training = True, may see
unexpected behavior, consider placing module in eval mode before exporting the TorchScript module)WARN");
}

// If gamma is None this fails
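
The fix the new warning suggests, sketched for a scripted module (the file name is illustrative; for traced modules the training-time behavior is baked in at trace time, so trace in eval mode on the export side instead):

```c++
#include "torch/script.h"

int main() {
  // Put the module in eval mode before compilation so converters such as
  // batch_norm see training = False.
  auto mod = torch::jit::load("vgg16_scripted.ts");
  mod.eval();
  // ... hand `mod` to trtorch::CompileGraph as usual ...
  return 0;
}
```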
9 changes: 3 additions & 6 deletions core/conversion/converters/impl/pooling.cpp
@@ -79,20 +79,17 @@ auto pooling_registrations = RegisterNodeConversionPatterns()
for (size_t i = 0; i < out_shape.size(); i++) {
stride[(stride.size() - 1) - i] = in_shape[(in_shape.size() - 1) - i] / out_shape[(out_shape.size() - 1) - i];
}
LOG_DEBUG("Stride" << util::toDims(stride));
LOG_DEBUG("Stride: " << util::toDims(stride));

std::vector<int64_t> window(out_shape.size());
for (size_t i = 0; i < out_shape.size(); i++) {
window[window.size() - 1 - i] = in_shape[in_shape.size() - 1 - i] - (out_shape[out_shape.size() - 1 - i] - 1) * stride[stride.size() - 1 - i];
}

LOG_DEBUG("Window" << util::toDims(window));
LOG_DEBUG("Window: " << util::toDims(window));

auto new_layer = ctx->net->addPoolingNd(*in, nvinfer1::PoolingType::kAVERAGE, util::toDims(window));
if (!new_layer) {
LOG_ERROR("Unable to create average pooling layer from node: " << *n);
return false;
}
TRTORCH_CHECK(new_layer, "Unable to create average pooling layer from node: " << *n);

new_layer->setStrideNd(util::toDims(stride));
