From d8f5d29bb917728f2ded8c086c1c60cfb7ccd85c Mon Sep 17 00:00:00 2001 From: Dheeraj Peri Date: Fri, 30 Jul 2021 15:40:25 -0700 Subject: [PATCH] feat(//cpp/int8/qat): QAT application release Signed-off-by: Dheeraj Peri --- cpp/int8/benchmark/BUILD | 17 ++ cpp/int8/benchmark/benchmark.cpp | 73 +++++ cpp/int8/benchmark/benchmark.h | 4 + cpp/int8/benchmark/timer.h | 44 +++ cpp/int8/datasets/BUILD | 14 + cpp/int8/datasets/cifar10.cpp | 137 ++++++++++ cpp/int8/datasets/cifar10.h | 45 ++++ cpp/int8/ptq/BUILD | 22 ++ cpp/int8/ptq/README.md | 159 +++++++++++ cpp/int8/ptq/main.cpp | 152 +++++++++++ cpp/int8/qat/BUILD | 22 ++ cpp/int8/qat/README.md | 159 +++++++++++ cpp/int8/qat/main.cpp | 139 ++++++++++ cpp/int8/training/vgg16/README.md | 39 +++ cpp/int8/training/vgg16/export_ckpt.py | 79 ++++++ cpp/int8/training/vgg16/export_qat.py | 103 +++++++ cpp/int8/training/vgg16/main.py | 224 +++++++++++++++ cpp/int8/training/vgg16/requirements.txt | 2 + cpp/int8/training/vgg16/train_qat.py | 329 +++++++++++++++++++++++ cpp/int8/training/vgg16/vgg16.py | 56 ++++ 20 files changed, 1819 insertions(+) create mode 100644 cpp/int8/benchmark/BUILD create mode 100644 cpp/int8/benchmark/benchmark.cpp create mode 100644 cpp/int8/benchmark/benchmark.h create mode 100644 cpp/int8/benchmark/timer.h create mode 100644 cpp/int8/datasets/BUILD create mode 100644 cpp/int8/datasets/cifar10.cpp create mode 100644 cpp/int8/datasets/cifar10.h create mode 100644 cpp/int8/ptq/BUILD create mode 100644 cpp/int8/ptq/README.md create mode 100644 cpp/int8/ptq/main.cpp create mode 100644 cpp/int8/qat/BUILD create mode 100644 cpp/int8/qat/README.md create mode 100644 cpp/int8/qat/main.cpp create mode 100644 cpp/int8/training/vgg16/README.md create mode 100644 cpp/int8/training/vgg16/export_ckpt.py create mode 100644 cpp/int8/training/vgg16/export_qat.py create mode 100644 cpp/int8/training/vgg16/main.py create mode 100644 cpp/int8/training/vgg16/requirements.txt create mode 100644 cpp/int8/training/vgg16/train_qat.py create mode 100644 cpp/int8/training/vgg16/vgg16.py diff --git a/cpp/int8/benchmark/BUILD b/cpp/int8/benchmark/BUILD new file mode 100644 index 0000000000..bd83d731c3 --- /dev/null +++ b/cpp/int8/benchmark/BUILD @@ -0,0 +1,17 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "benchmark", + srcs = [ + "benchmark.cpp", + "timer.h", + ], + hdrs = [ + "benchmark.h", + ], + deps = [ + "//cpp/api:trtorch", + "@libtorch", + "@libtorch//:caffe2", + ], +) diff --git a/cpp/int8/benchmark/benchmark.cpp b/cpp/int8/benchmark/benchmark.cpp new file mode 100644 index 0000000000..66c6129487 --- /dev/null +++ b/cpp/int8/benchmark/benchmark.cpp @@ -0,0 +1,73 @@ +#include "ATen/Context.h" +#include "c10/cuda/CUDACachingAllocator.h" +#include "cuda_runtime_api.h" +#include "torch/script.h" +#include "torch/torch.h" +#include "trtorch/trtorch.h" + +#include "timer.h" + +#define NUM_WARMUP_RUNS 20 +#define NUM_RUNS 100 + +// Benchmaking code +void print_avg_std_dev(std::string type, std::vector& runtimes, uint64_t batch_size) { + float avg_runtime = std::accumulate(runtimes.begin(), runtimes.end(), 0.0) / runtimes.size(); + float fps = (1000.f / avg_runtime) * batch_size; + std::cout << "[" << type << "]: batch_size: " << batch_size << "\n Average latency: " << avg_runtime + << " ms\n Average FPS: " << fps << " fps" << std::endl; + + std::vector rt_diff(runtimes.size()); + std::transform(runtimes.begin(), runtimes.end(), rt_diff.begin(), [avg_runtime](float x) { return x - avg_runtime; }); + float rt_sq_sum = std::inner_product(rt_diff.begin(), rt_diff.end(), rt_diff.begin(), 0.0); + float rt_std_dev = std::sqrt(rt_sq_sum / runtimes.size()); + + std::vector fps_diff(runtimes.size()); + std::transform(runtimes.begin(), runtimes.end(), fps_diff.begin(), [fps, batch_size](float x) { + return ((1000.f / x) * batch_size) - fps; + }); + float fps_sq_sum = std::inner_product(fps_diff.begin(), fps_diff.end(), fps_diff.begin(), 0.0); + float fps_std_dev = std::sqrt(fps_sq_sum / runtimes.size()); + std::cout << " Latency Standard Deviation: " << rt_std_dev << "\n FPS Standard Deviation: " << fps_std_dev + << "\n(excluding initial warmup runs)" << std::endl; +} + +std::vector benchmark_module(torch::jit::script::Module& mod, std::vector shape) { + auto execution_timer = timers::PreciseCPUTimer(); + std::vector execution_runtimes; + + for (uint64_t i = 0; i < NUM_WARMUP_RUNS; i++) { + std::vector inputs_ivalues; + auto in = at::rand(shape, {at::kCUDA}); +#ifdef HALF + in = in.to(torch::kHalf); +#endif + inputs_ivalues.push_back(in.clone()); + + cudaDeviceSynchronize(); + mod.forward(inputs_ivalues); + cudaDeviceSynchronize(); + } + + for (uint64_t i = 0; i < NUM_RUNS; i++) { + std::vector inputs_ivalues; + auto in = at::rand(shape, {at::kCUDA}); +#ifdef HALF + in = in.to(torch::kHalf); +#endif + inputs_ivalues.push_back(in.clone()); + cudaDeviceSynchronize(); + + execution_timer.start(); + mod.forward(inputs_ivalues); + cudaDeviceSynchronize(); + execution_timer.stop(); + + auto time = execution_timer.milliseconds(); + execution_timer.reset(); + execution_runtimes.push_back(time); + + c10::cuda::CUDACachingAllocator::emptyCache(); + } + return execution_runtimes; +} diff --git a/cpp/int8/benchmark/benchmark.h b/cpp/int8/benchmark/benchmark.h new file mode 100644 index 0000000000..3c11833ab3 --- /dev/null +++ b/cpp/int8/benchmark/benchmark.h @@ -0,0 +1,4 @@ +#pragma once + +void print_avg_std_dev(std::string type, std::vector& runtimes, uint64_t batch_size); +std::vector benchmark_module(torch::jit::script::Module& mod, std::vector shape); diff --git a/cpp/int8/benchmark/timer.h b/cpp/int8/benchmark/timer.h new file mode 100644 index 0000000000..72ef142671 --- /dev/null +++ b/cpp/int8/benchmark/timer.h @@ -0,0 +1,44 @@ +#pragma once +#include + +namespace timers { +class TimerBase { + public: + virtual void start() {} + virtual void stop() {} + float microseconds() const noexcept { + return mMs * 1000.f; + } + float milliseconds() const noexcept { + return mMs; + } + float seconds() const noexcept { + return mMs / 1000.f; + } + void reset() noexcept { + mMs = 0.f; + } + + protected: + float mMs{0.0f}; +}; + +template +class CPUTimer : public TimerBase { + public: + using clock_type = Clock; + + void start() { + mStart = Clock::now(); + } + void stop() { + mStop = Clock::now(); + mMs += std::chrono::duration{mStop - mStart}.count(); + } + + private: + std::chrono::time_point mStart, mStop; +}; // class CPUTimer + +using PreciseCPUTimer = CPUTimer; +} // namespace timers diff --git a/cpp/int8/datasets/BUILD b/cpp/int8/datasets/BUILD new file mode 100644 index 0000000000..f2e560b3f3 --- /dev/null +++ b/cpp/int8/datasets/BUILD @@ -0,0 +1,14 @@ +package(default_visibility = ["//visibility:public"]) + +cc_library( + name = "cifar10", + srcs = [ + "cifar10.cpp", + ], + hdrs = [ + "cifar10.h", + ], + deps = [ + "@libtorch", + ], +) diff --git a/cpp/int8/datasets/cifar10.cpp b/cpp/int8/datasets/cifar10.cpp new file mode 100644 index 0000000000..161a9989cf --- /dev/null +++ b/cpp/int8/datasets/cifar10.cpp @@ -0,0 +1,137 @@ +// #include "cpp/int8/ptq/datasets/cifar10.h" +#include "cifar10.h" +#include "torch/data/example.h" +#include "torch/torch.h" +#include "torch/types.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace datasets { +namespace { +constexpr const char* kTrainFilenamePrefix = "data_batch_"; +constexpr const uint32_t kNumTrainFiles = 5; +constexpr const char* kTestFilename = "test_batch.bin"; +constexpr const size_t kLabelSize = 1; // B +constexpr const size_t kImageSize = 3072; // B +constexpr const size_t kImageDim = 32; +constexpr const size_t kImageChannels = 3; +constexpr const size_t kBatchSize = 10000; + +std::pair read_batch(const std::string& path) { + std::ifstream batch; + batch.open(path, std::ios::in | std::ios::binary | std::ios::ate); + + auto file_size = batch.tellg(); + std::unique_ptr buf(new char[file_size]); + + batch.seekg(0, std::ios::beg); + batch.read(buf.get(), file_size); + batch.close(); + + std::vector labels; + std::vector images; + labels.reserve(kBatchSize); + images.reserve(kBatchSize); + + for (size_t i = 0; i < kBatchSize; i++) { + uint8_t label = buf[i * (kImageSize + kLabelSize)]; + std::vector image; + image.reserve(kImageSize); + std::copy( + &buf[i * (kImageSize + kLabelSize) + 1], + &buf[i * (kImageSize + kLabelSize) + kImageSize], + std::back_inserter(image)); + labels.push_back(label); + auto image_tensor = + torch::from_blob(image.data(), {kImageChannels, kImageDim, kImageDim}, torch::TensorOptions().dtype(torch::kU8)) + .to(torch::kF32); + images.push_back(image_tensor); + } + + auto labels_tensor = + torch::from_blob(labels.data(), {kBatchSize}, torch::TensorOptions().dtype(torch::kU8)).to(torch::kF32); + assert(labels_tensor.size(0) == kBatchSize); + + auto images_tensor = torch::stack(images); + assert(images_tensor.size(0) == kBatchSize); + + return std::make_pair(images_tensor, labels_tensor); +} + +std::pair read_train_data(const std::string& root) { + std::vector images, targets; + for (uint32_t i = 1; i <= 5; i++) { + std::stringstream ss; + ss << root << '/' << kTrainFilenamePrefix << i << ".bin"; + auto batch = read_batch(ss.str()); + images.push_back(batch.first); + targets.push_back(batch.second); + } + + torch::Tensor image_tensor = + std::accumulate(++images.begin(), images.end(), *images.begin(), [&](torch::Tensor a, torch::Tensor b) { + return torch::cat({a, b}, 0); + }); + torch::Tensor target_tensor = + std::accumulate(++targets.begin(), targets.end(), *targets.begin(), [&](torch::Tensor a, torch::Tensor b) { + return torch::cat({a, b}, 0); + }); + + return std::make_pair(image_tensor, target_tensor); +} + +std::pair read_test_data(const std::string& root) { + std::stringstream ss; + ss << root << '/' << kTestFilename; + return read_batch(ss.str()); +} +} // namespace + +CIFAR10::CIFAR10(const std::string& root, Mode mode) : mode_(mode) { + std::pair data; + if (mode_ == Mode::kTrain) { + data = read_train_data(root); + } else { + data = read_test_data(root); + } + + images_ = std::move(data.first); + targets_ = std::move(data.second); + assert(images_.sizes()[0] == images_.sizes()[0]); +} + +torch::data::Example<> CIFAR10::get(size_t index) { + return {images_[index], targets_[index]}; +} + +c10::optional CIFAR10::size() const { + return images_.size(0); +} + +bool CIFAR10::is_train() const noexcept { + return mode_ == Mode::kTrain; +} + +const torch::Tensor& CIFAR10::images() const { + return images_; +} + +const torch::Tensor& CIFAR10::targets() const { + return targets_; +} + +CIFAR10&& CIFAR10::use_subset(int64_t new_size) { + assert(new_size <= images_.sizes()[0]); + images_ = images_.slice(0, 0, new_size); + targets_ = targets_.slice(0, 0, new_size); + return std::move(*this); +} + +} // namespace datasets diff --git a/cpp/int8/datasets/cifar10.h b/cpp/int8/datasets/cifar10.h new file mode 100644 index 0000000000..71861b24c9 --- /dev/null +++ b/cpp/int8/datasets/cifar10.h @@ -0,0 +1,45 @@ +#pragma once + +#include "torch/data/datasets/base.h" +#include "torch/data/example.h" +#include "torch/types.h" + +#include +#include + +namespace datasets { +// The CIFAR10 Dataset +class CIFAR10 : public torch::data::datasets::Dataset { + public: + // The mode in which the dataset is loaded + enum class Mode { kTrain, kTest }; + + // Loads CIFAR10 from un-tarred file + // Dataset can be found + // https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz Root path should be + // the directory that contains the content of tarball + explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain); + + // Returns the pair at index in the dataset + torch::data::Example<> get(size_t index) override; + + // The size of the dataset + c10::optional size() const override; + + // The mode the dataset is in + bool is_train() const noexcept; + + // Returns all images stacked into a single tensor + const torch::Tensor& images() const; + + // Returns all targets stacked into a single tensor + const torch::Tensor& targets() const; + + // Trims the dataset to the first n pairs + CIFAR10&& use_subset(int64_t new_size); + + private: + Mode mode_; + torch::Tensor images_, targets_; +}; +} // namespace datasets diff --git a/cpp/int8/ptq/BUILD b/cpp/int8/ptq/BUILD new file mode 100644 index 0000000000..6e79998dce --- /dev/null +++ b/cpp/int8/ptq/BUILD @@ -0,0 +1,22 @@ +package(default_visibility = ["//visibility:public"]) + +cc_binary( + name = "ptq", + srcs = [ + "main.cpp", + ], + copts = [ + "-pthread", + ], + linkopts = [ + "-lpthread", + ], + deps = [ + "//cpp/api:trtorch", + "//cpp/int8/benchmark", + "//cpp/int8/datasets:cifar10", + "@libtorch", + "@libtorch//:caffe2", + "@tensorrt//:nvinfer", + ], +) diff --git a/cpp/int8/ptq/README.md b/cpp/int8/ptq/README.md new file mode 100644 index 0000000000..7cb179cd64 --- /dev/null +++ b/cpp/int8/ptq/README.md @@ -0,0 +1,159 @@ +# ptq + +## How to create your own PTQ application + +Post Training Quantization (PTQ) is a technique to reduce the required computational resources for inference while still preserving the accuracy of your model by mapping the traditional FP32 activation space to a reduced INT8 space. TensorRT uses a calibration step which executes your model with sample data from the target domain and track the activations in FP32 to calibrate a mapping to INT8 that minimizes the information loss between FP32 inference and INT8 inference. + +Users writing TensorRT applications are required to setup a calibrator class which will provide sample data to the TensorRT calibrator. With TRTorch we look to leverage existing infrastructure in PyTorch to make implementing calibrators easier. + +LibTorch provides a `Dataloader` and `Dataset` API which steamlines preprocessing and batching input data. TRTorch uses Dataloaders as the base of a generic calibrator implementation. So you will be able to reuse or quickly implement a `torch::Dataset` for your target domain, place it in a Dataloader and create a INT8 Calibrator from it which you can provide to TRTorch to run INT8 Calibration during compliation of your module. + +### Code + +Here is an example interface of a `torch::Dataset` class for CIFAR10: + +```C++ +//cpp/ptq/datasets/cifar10.h +#pragma once + +#include "torch/data/datasets/base.h" +#include "torch/data/example.h" +#include "torch/types.h" + +#include +#include + +namespace datasets { +// The CIFAR10 Dataset +class CIFAR10 : public torch::data::datasets::Dataset { +public: + // The mode in which the dataset is loaded + enum class Mode { kTrain, kTest }; + + // Loads CIFAR10 from un-tarred file + // Dataset can be found https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz + // Root path should be the directory that contains the content of tarball + explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain); + + // Returns the pair at index in the dataset + torch::data::Example<> get(size_t index) override; + + // The size of the dataset + c10::optional size() const override; + + // The mode the dataset is in + bool is_train() const noexcept; + + // Returns all images stacked into a single tensor + const torch::Tensor& images() const; + + // Returns all targets stacked into a single tensor + const torch::Tensor& targets() const; + + // Trims the dataset to the first n pairs + CIFAR10&& use_subset(int64_t new_size); + + +private: + Mode mode_; + torch::Tensor images_, targets_; +}; +} // namespace datasets +``` + +This class's implementation reads from the binary distribution of the CIFAR10 dataset and builds two tensors which hold the images and labels. + +Then we select a subset of the dataset to use for calibration, since we don't need the the full dataset for calibration and calibration does take time, then define the preprocessing to apply to the images in the dataset and create a Dataloader from the dataset which will batch the data: + +```C++ +auto calibration_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest) + .use_subset(320) + .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, + {0.2023, 0.1994, 0.2010})) + .map(torch::data::transforms::Stack<>()); +auto calibration_dataloader = torch::data::make_data_loader(std::move(calibration_dataset), + torch::data::DataLoaderOptions().batch_size(32) + .workers(2)); +``` + +Next we create a calibrator from the `calibration_dataloader` using the calibrator factory: + +```C++ +auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true); + +``` + +Here we also define a location to write a calibration cache file to which we can use to reuse the calibration data without needing the dataset and whether or not we should use the cache file if it exists. There also exists a `trtorch::ptq::make_int8_cache_calibrator` factory which creates a calibrator that uses the cache only for cases where you may do engine building on a machine that has limited storage (i.e. no space for a dataset) or to have a simpiler deployment application. + +The calibrator factories create a calibrator that inherits from a `nvinfer1::IInt8Calibrator` virtual class (`nvinfer1::IInt8EntropyCalibrator2` by default) which defines the calibration algorithm used when calibrating. You can explicitly make the selection of calibration algorithm like this: + +```C++ +// MinMax Calibrator is geared more towards NLP tasks +auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true); +``` + +Then all thats required to setup the module for INT8 calibration is to set the following compile settings in the `trtorch::CompileSpec` struct and compiling the module: + +```C++ + std::vector> input_shape = {{32, 3, 32, 32}}; + /// Configure settings for compilation + auto compile_spec = trtorch::CompileSpec({input_shape}); + /// Set enable INT8 precision + compile_spec.enabled_precisions.insert(torch::kI8); + /// Use the TensorRT Entropy Calibrator + compile_spec.ptq_calibrator = calibrator; + /// Set a larger workspace (you may get better performace from doing so) + compile_spec.workspace_size = 1 << 28; + + auto trt_mod = trtorch::CompileGraph(mod, compile_spec); +``` + +If you have an existing Calibrator implementation for TensorRT you may directly set the `ptq_calibrator` field with a pointer to your calibrator and it will work as well. + +From here not much changes in terms of how to execution works. You are still able to fully use Libtorch as the sole interface for inference. Data should remain in FP32 precision when it's passed into `trt_mod.forward`. + + +## Running the Example Application + +This is a short example application that shows how to use TRTorch to perform post-training quantization for a module. + +## Prerequisites + +1. Download CIFAR10 Dataset Binary version ([https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz](https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz)) +2. Train a network on CIFAR10 (see `training/` for a VGG16 recipie) +3. Export model to torchscript + +## Compilation + +``` shell +bazel build //cpp/ptq --compilation_mode=opt +``` + +If you want insight into what is going under the hood or need debug symbols + +``` shell +bazel build //cpp/ptq --compilation_mode=dbg +``` + +## Usage + +``` shell +ptq +``` + +## Example Output + +``` +Accuracy of JIT model on test set: 92.1% +Compiling and quantizing module +Accuracy of quantized model on test set: 91.0044% +Latency of JIT model FP32 (Batch Size 32): 1.73497ms +Latency of quantized model (Batch Size 32): 0.365737ms +``` + +## Citations + +``` +Krizhevsky, A., & Hinton, G. (2009). Learning multiple layers of features from tiny images. +Simonyan, K., & Zisserman, A. (2014). Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556. +``` diff --git a/cpp/int8/ptq/main.cpp b/cpp/int8/ptq/main.cpp new file mode 100644 index 0000000000..fddb4abcc6 --- /dev/null +++ b/cpp/int8/ptq/main.cpp @@ -0,0 +1,152 @@ +#include "torch/script.h" +#include "torch/torch.h" +#include "trtorch/ptq.h" +#include "trtorch/trtorch.h" + +#include "NvInfer.h" + +#include "cpp/int8/benchmark/benchmark.h" +#include "cpp/int8/datasets/cifar10.h" + +#include +#include +#include +#include + +namespace F = torch::nn::functional; + +// Actual PTQ application code +struct Resize : public torch::data::transforms::TensorTransform { + Resize(std::vector new_size) : new_size_(new_size) {} + + torch::Tensor operator()(torch::Tensor input) { + input = input.unsqueeze(0); + auto upsampled = + F::interpolate(input, F::InterpolateFuncOptions().size(new_size_).align_corners(false).mode(torch::kBilinear)); + return upsampled.squeeze(0); + } + + std::vector new_size_; +}; + +torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::Module& mod) { + auto calibration_dataset = + datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest) + .use_subset(320) + .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, {0.2023, 0.1994, 0.2010})) + .map(torch::data::transforms::Stack<>()); + auto calibration_dataloader = torch::data::make_data_loader( + std::move(calibration_dataset), torch::data::DataLoaderOptions().batch_size(32).workers(2)); + + std::string calibration_cache_file = "/tmp/vgg16_TRT_ptq_calibration.cache"; + + auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true); + + std::vector> input_shape = {{32, 3, 32, 32}}; + /// Configure settings for compilation + auto compile_spec = trtorch::CompileSpec({input_shape}); + /// Set operating precision to INT8 + compile_spec.enabled_precisions.insert(torch::kI8); + /// Use the TensorRT Entropy Calibrator + compile_spec.ptq_calibrator = calibrator; + /// Set max batch size for the engine + compile_spec.max_batch_size = 32; + /// Set a larger workspace + compile_spec.workspace_size = 1 << 28; + + mod.eval(); + +#ifdef SAVE_ENGINE + std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl; + auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec); + std::ofstream out("/tmp/engine_converted_from_jit.trt"); + out << engine; + out.close(); +#endif + + std::cout << "Compiling and quantizing module" << std::endl; + auto trt_mod = trtorch::CompileGraph(mod, compile_spec); + return std::move(trt_mod); +} + +int main(int argc, const char* argv[]) { + at::globalContext().setBenchmarkCuDNN(true); + + if (argc < 3) { + std::cerr << "usage: ptq \n"; + return -1; + } + + torch::jit::Module mod; + try { + /// Deserialize the ScriptModule from a file using torch::jit::load(). + mod = torch::jit::load(argv[1]); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + return -1; + } + + /// Create the calibration dataset + const std::string data_dir = std::string(argv[2]); + auto trt_mod = compile_int8_model(data_dir, mod); + + /// Dataloader moved into calibrator so need another for inference + auto eval_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest) + .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, {0.2023, 0.1994, 0.2010})) + .map(torch::data::transforms::Stack<>()); + auto eval_dataloader = torch::data::make_data_loader( + std::move(eval_dataset), torch::data::DataLoaderOptions().batch_size(32).workers(2)); + + /// Check the FP32 accuracy in JIT + float correct = 0.0, total = 0.0; + for (auto batch : *eval_dataloader) { + auto images = batch.data.to(torch::kCUDA); + auto targets = batch.target.to(torch::kCUDA); + + auto outputs = mod.forward({images}); + auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false)); + + total += targets.sizes()[0]; + correct += torch::sum(torch::eq(predictions, targets)).item().toFloat(); + } + std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%" << std::endl; + + /// Check the INT8 accuracy in TRT + correct = 0.0; + total = 0.0; + for (auto batch : *eval_dataloader) { + auto images = batch.data.to(torch::kCUDA); + auto targets = batch.target.to(torch::kCUDA); + + if (images.sizes()[0] < 32) { + /// To handle smaller batches util Optimization profiles work with Int8 + auto diff = 32 - images.sizes()[0]; + auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA}); + auto target_padding = torch::zeros({diff}, {torch::kCUDA}); + images = torch::cat({images, img_padding}, 0); + targets = torch::cat({targets, target_padding}, 0); + } + + auto outputs = trt_mod.forward({images}); + auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false)); + predictions = predictions.reshape(predictions.sizes()[0]); + + if (predictions.sizes()[0] != targets.sizes()[0]) { + /// To handle smaller batches util Optimization profiles work with Int8 + predictions = predictions.slice(0, 0, targets.sizes()[0]); + } + + total += targets.sizes()[0]; + correct += torch::sum(torch::eq(predictions, targets)).item().toFloat(); + } + std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" << std::endl; + + /// Time execution in JIT-FP32 and TRT-INT8 + std::vector> dims = {{32, 3, 32, 32}}; + + auto jit_runtimes = benchmark_module(mod, dims[0]); + print_avg_std_dev("JIT model FP32", jit_runtimes, dims[0][0]); + + auto trt_runtimes = benchmark_module(trt_mod, dims[0]); + print_avg_std_dev("TRT quantized model", trt_runtimes, dims[0][0]); +} diff --git a/cpp/int8/qat/BUILD b/cpp/int8/qat/BUILD new file mode 100644 index 0000000000..f322ab7385 --- /dev/null +++ b/cpp/int8/qat/BUILD @@ -0,0 +1,22 @@ +package(default_visibility = ["//visibility:public"]) + +cc_binary( + name = "qat", + srcs = [ + "main.cpp", + ], + copts = [ + "-pthread", + ], + linkopts = [ + "-lpthread", + ], + deps = [ + "//cpp/api:trtorch", + "//cpp/qat/benchmark", + "//cpp/qat/datasets:cifar10", + "@libtorch", + "@libtorch//:caffe2", + "@tensorrt//:nvinfer", + ], +) diff --git a/cpp/int8/qat/README.md b/cpp/int8/qat/README.md new file mode 100644 index 0000000000..7cb179cd64 --- /dev/null +++ b/cpp/int8/qat/README.md @@ -0,0 +1,159 @@ +# ptq + +## How to create your own PTQ application + +Post Training Quantization (PTQ) is a technique to reduce the required computational resources for inference while still preserving the accuracy of your model by mapping the traditional FP32 activation space to a reduced INT8 space. TensorRT uses a calibration step which executes your model with sample data from the target domain and track the activations in FP32 to calibrate a mapping to INT8 that minimizes the information loss between FP32 inference and INT8 inference. + +Users writing TensorRT applications are required to setup a calibrator class which will provide sample data to the TensorRT calibrator. With TRTorch we look to leverage existing infrastructure in PyTorch to make implementing calibrators easier. + +LibTorch provides a `Dataloader` and `Dataset` API which steamlines preprocessing and batching input data. TRTorch uses Dataloaders as the base of a generic calibrator implementation. So you will be able to reuse or quickly implement a `torch::Dataset` for your target domain, place it in a Dataloader and create a INT8 Calibrator from it which you can provide to TRTorch to run INT8 Calibration during compliation of your module. + +### Code + +Here is an example interface of a `torch::Dataset` class for CIFAR10: + +```C++ +//cpp/ptq/datasets/cifar10.h +#pragma once + +#include "torch/data/datasets/base.h" +#include "torch/data/example.h" +#include "torch/types.h" + +#include +#include + +namespace datasets { +// The CIFAR10 Dataset +class CIFAR10 : public torch::data::datasets::Dataset { +public: + // The mode in which the dataset is loaded + enum class Mode { kTrain, kTest }; + + // Loads CIFAR10 from un-tarred file + // Dataset can be found https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz + // Root path should be the directory that contains the content of tarball + explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain); + + // Returns the pair at index in the dataset + torch::data::Example<> get(size_t index) override; + + // The size of the dataset + c10::optional size() const override; + + // The mode the dataset is in + bool is_train() const noexcept; + + // Returns all images stacked into a single tensor + const torch::Tensor& images() const; + + // Returns all targets stacked into a single tensor + const torch::Tensor& targets() const; + + // Trims the dataset to the first n pairs + CIFAR10&& use_subset(int64_t new_size); + + +private: + Mode mode_; + torch::Tensor images_, targets_; +}; +} // namespace datasets +``` + +This class's implementation reads from the binary distribution of the CIFAR10 dataset and builds two tensors which hold the images and labels. + +Then we select a subset of the dataset to use for calibration, since we don't need the the full dataset for calibration and calibration does take time, then define the preprocessing to apply to the images in the dataset and create a Dataloader from the dataset which will batch the data: + +```C++ +auto calibration_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest) + .use_subset(320) + .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, + {0.2023, 0.1994, 0.2010})) + .map(torch::data::transforms::Stack<>()); +auto calibration_dataloader = torch::data::make_data_loader(std::move(calibration_dataset), + torch::data::DataLoaderOptions().batch_size(32) + .workers(2)); +``` + +Next we create a calibrator from the `calibration_dataloader` using the calibrator factory: + +```C++ +auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true); + +``` + +Here we also define a location to write a calibration cache file to which we can use to reuse the calibration data without needing the dataset and whether or not we should use the cache file if it exists. There also exists a `trtorch::ptq::make_int8_cache_calibrator` factory which creates a calibrator that uses the cache only for cases where you may do engine building on a machine that has limited storage (i.e. no space for a dataset) or to have a simpiler deployment application. + +The calibrator factories create a calibrator that inherits from a `nvinfer1::IInt8Calibrator` virtual class (`nvinfer1::IInt8EntropyCalibrator2` by default) which defines the calibration algorithm used when calibrating. You can explicitly make the selection of calibration algorithm like this: + +```C++ +// MinMax Calibrator is geared more towards NLP tasks +auto calibrator = trtorch::ptq::make_int8_calibrator(std::move(calibration_dataloader), calibration_cache_file, true); +``` + +Then all thats required to setup the module for INT8 calibration is to set the following compile settings in the `trtorch::CompileSpec` struct and compiling the module: + +```C++ + std::vector> input_shape = {{32, 3, 32, 32}}; + /// Configure settings for compilation + auto compile_spec = trtorch::CompileSpec({input_shape}); + /// Set enable INT8 precision + compile_spec.enabled_precisions.insert(torch::kI8); + /// Use the TensorRT Entropy Calibrator + compile_spec.ptq_calibrator = calibrator; + /// Set a larger workspace (you may get better performace from doing so) + compile_spec.workspace_size = 1 << 28; + + auto trt_mod = trtorch::CompileGraph(mod, compile_spec); +``` + +If you have an existing Calibrator implementation for TensorRT you may directly set the `ptq_calibrator` field with a pointer to your calibrator and it will work as well. + +From here not much changes in terms of how to execution works. You are still able to fully use Libtorch as the sole interface for inference. Data should remain in FP32 precision when it's passed into `trt_mod.forward`. + + +## Running the Example Application + +This is a short example application that shows how to use TRTorch to perform post-training quantization for a module. + +## Prerequisites + +1. Download CIFAR10 Dataset Binary version ([https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz](https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz)) +2. Train a network on CIFAR10 (see `training/` for a VGG16 recipie) +3. Export model to torchscript + +## Compilation + +``` shell +bazel build //cpp/ptq --compilation_mode=opt +``` + +If you want insight into what is going under the hood or need debug symbols + +``` shell +bazel build //cpp/ptq --compilation_mode=dbg +``` + +## Usage + +``` shell +ptq +``` + +## Example Output + +``` +Accuracy of JIT model on test set: 92.1% +Compiling and quantizing module +Accuracy of quantized model on test set: 91.0044% +Latency of JIT model FP32 (Batch Size 32): 1.73497ms +Latency of quantized model (Batch Size 32): 0.365737ms +``` + +## Citations + +``` +Krizhevsky, A., & Hinton, G. (2009). Learning multiple layers of features from tiny images. +Simonyan, K., & Zisserman, A. (2014). Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556. +``` diff --git a/cpp/int8/qat/main.cpp b/cpp/int8/qat/main.cpp new file mode 100644 index 0000000000..0ed9627619 --- /dev/null +++ b/cpp/int8/qat/main.cpp @@ -0,0 +1,139 @@ +#include "torch/script.h" +#include "torch/torch.h" +#include "trtorch/ptq.h" +#include "trtorch/trtorch.h" + +#include "NvInfer.h" + +#include "cpp/int8/benchmark/benchmark.h" +#include "cpp/int8/datasets/cifar10.h" + +#include +#include +#include +#include + +namespace F = torch::nn::functional; + +struct Resize : public torch::data::transforms::TensorTransform { + Resize(std::vector new_size) : new_size_(new_size) {} + + torch::Tensor operator()(torch::Tensor input) { + input = input.unsqueeze(0); + auto upsampled = + F::interpolate(input, F::InterpolateFuncOptions().size(new_size_).align_corners(false).mode(torch::kBilinear)); + return upsampled.squeeze(0); + } + + std::vector new_size_; +}; + +torch::jit::Module compile_int8_qat_model(torch::jit::Module& mod) { + std::vector> input_shape = {{32, 3, 32, 32}}; + /// Configure settings for compilation + auto compile_spec = trtorch::CompileSpec({input_shape}); + /// Set operating precision to INT8 + compile_spec.enabled_precisions.insert(torch::kI8); + /// Set max batch size for the engine + compile_spec.max_batch_size = 32; + /// Set a larger workspace + compile_spec.workspace_size = 1 << 28; + + mod.eval(); + +#ifdef SAVE_ENGINE + std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl; + auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", compile_spec); + std::ofstream out("/tmp/engine_converted_from_jit.trt"); + out << engine; + out.close(); +#endif + + std::cout << "Compiling and quantizing module" << std::endl; + auto trt_mod = trtorch::CompileGraph(mod, compile_spec); + return std::move(trt_mod); +} + +int main(int argc, const char* argv[]) { + at::globalContext().setBenchmarkCuDNN(true); + + if (argc < 3) { + std::cerr << "usage: qat \n"; + return -1; + } + + torch::jit::Module mod; + try { + /// Deserialize the ScriptModule from a file using torch::jit::load(). + mod = torch::jit::load(argv[1]); + } catch (const c10::Error& e) { + std::cerr << "error loading the model\n"; + return -1; + } + + /// Convert the model using TensorRT + auto trt_mod = compile_int8_qat_model(mod); + std::cout << "Model conversion to TensorRT completed." << std::endl; + /// Dataloader moved into calibrator so need another for inference + const std::string data_dir = std::string(argv[2]); + auto eval_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest) + .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465}, {0.2023, 0.1994, 0.2010})) + .map(torch::data::transforms::Stack<>()); + auto eval_dataloader = torch::data::make_data_loader( + std::move(eval_dataset), torch::data::DataLoaderOptions().batch_size(128).workers(2)); + + /// Check the FP32 accuracy in JIT + float correct = 0.0, total = 0.0; + for (auto batch : *eval_dataloader) { + auto images = batch.data.to(torch::kCUDA); + auto targets = batch.target.to(torch::kCUDA); + + auto outputs = mod.forward({images}); + auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false)); + + total += targets.sizes()[0]; + correct += torch::sum(torch::eq(predictions, targets)).item().toFloat(); + } + std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%" + << " correct: " << correct << " total: " << total << std::endl; + + /// Check the INT8 accuracy in TRT + correct = 0.0; + total = 0.0; + for (auto batch : *eval_dataloader) { + auto images = batch.data.to(torch::kCUDA); + auto targets = batch.target.to(torch::kCUDA); + + if (images.sizes()[0] < 32) { + /// To handle smaller batches util Optimization profiles work with Int8 + auto diff = 32 - images.sizes()[0]; + auto img_padding = torch::zeros({diff, 3, 32, 32}, {torch::kCUDA}); + auto target_padding = torch::zeros({diff}, {torch::kCUDA}); + images = torch::cat({images, img_padding}, 0); + targets = torch::cat({targets, target_padding}, 0); + } + + auto outputs = trt_mod.forward({images}); + auto predictions = std::get<1>(torch::max(outputs.toTensor(), 1, false)); + predictions = predictions.reshape(predictions.sizes()[0]); + + if (predictions.sizes()[0] != targets.sizes()[0]) { + /// To handle smaller batches util Optimization profiles work with Int8 + predictions = predictions.slice(0, 0, targets.sizes()[0]); + } + + total += targets.sizes()[0]; + correct += torch::sum(torch::eq(predictions, targets)).item().toFloat(); + } + std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" + << " correct: " << correct << " total: " << total << std::endl; + + /// Time execution in JIT-FP32 and TRT-INT8 + std::vector> dims = {{32, 3, 32, 32}}; + + auto jit_runtimes = benchmark_module(mod, dims[0]); + print_avg_std_dev("JIT model FP32", jit_runtimes, dims[0][0]); + + auto trt_runtimes = benchmark_module(trt_mod, dims[0]); + print_avg_std_dev("TRT quantized model", trt_runtimes, dims[0][0]); +} diff --git a/cpp/int8/training/vgg16/README.md b/cpp/int8/training/vgg16/README.md new file mode 100644 index 0000000000..a6e12953f9 --- /dev/null +++ b/cpp/int8/training/vgg16/README.md @@ -0,0 +1,39 @@ +# VGG16 Trained on CIFAR10 + +This is a recipe to train a VGG network on CIFAR10 to use with the TRTorch PTQ example. + +## Prequisites + +``` +pip3 install -r requirements.txt --user +``` + +## Training + +The following recipe should get somewhere between 89-92% accuracy on the CIFAR10 testset +``` +python3 main.py --lr 0.01 --batch-size 128 --drop-ratio 0.15 --ckpt-dir $(pwd)/vgg16_ckpts --epochs 100 +``` + +> 545 was the seed used in testing + +You can monitor training with tensorboard, logs are stored by default at `/tmp/vgg16_logs` + +## Exporting + +Use the exporter script to create a torchscipt module you can compile with TRTorch + +``` +python3 export_ckpt.py +``` + +It should produce a file called `trained_vgg16.jit.pt` + +Once the trained VGG network is exported run it with the PTQ example. + +## Citations + +``` +Krizhevsky, A., & Hinton, G. (2009). Learning multiple layers of features from tiny images. +Simonyan, K., & Zisserman, A. (2014). Very deep convolutional networks for large-scale image recognition. arXiv preprint arXiv:1409.1556. +``` diff --git a/cpp/int8/training/vgg16/export_ckpt.py b/cpp/int8/training/vgg16/export_ckpt.py new file mode 100644 index 0000000000..8c2e679b6b --- /dev/null +++ b/cpp/int8/training/vgg16/export_ckpt.py @@ -0,0 +1,79 @@ +import argparse +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data as data +import torchvision.transforms as transforms +import torchvision.datasets as datasets + +from vgg16 import vgg16 + + +def test(model, dataloader, crit): + global writer + global classes + total = 0 + correct = 0 + loss = 0.0 + class_probs = [] + class_preds = [] + + with torch.no_grad(): + for data, labels in dataloader: + data, labels = data.cuda(), labels.cuda(non_blocking=True) + out = model(data) + loss += crit(out, labels) + preds = torch.max(out, 1)[1] + class_probs.append([F.softmax(i, dim=0) for i in out]) + class_preds.append(preds) + total += labels.size(0) + correct += (preds == labels).sum().item() + + test_probs = torch.cat([torch.stack(batch) for batch in class_probs]) + test_preds = torch.cat(class_preds) + return loss / total, correct / total + + +PARSER = argparse.ArgumentParser(description="Export trained VGG") +PARSER.add_argument('ckpt', type=str, help="Path to saved checkpoint") + +args = PARSER.parse_args() +model = vgg16(num_classes=10, init_weights=False) +model = model.cuda() + +ckpt = torch.load(args.ckpt) +weights = ckpt["model_state_dict"] + +if torch.cuda.device_count() > 1: + from collections import OrderedDict + new_state_dict = OrderedDict() + for k, v in weights.items(): + name = k[7:] # remove `module.` + new_state_dict[name] = v + weights = new_state_dict + +model.load_state_dict(weights) + +# Setting eval here causes both JIT and TRT accuracy to tank in LibTorch will follow up with PyTorch Team +#model.eval() + +jit_model = torch.jit.trace(model, torch.rand([32, 3, 32, 32]).to("cuda")) +jit_model.eval() + +testing_dataset = datasets.CIFAR10(root='./data', + train=False, + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ])) + +testing_dataloader = torch.utils.data.DataLoader(testing_dataset, batch_size=32, shuffle=False, num_workers=2) + +crit = torch.nn.CrossEntropyLoss() + +test_loss, test_acc = test(jit_model, testing_dataloader, crit) +print("[JIT] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc)) +torch.jit.save(jit_model, "trained_vgg16.jit.pt") diff --git a/cpp/int8/training/vgg16/export_qat.py b/cpp/int8/training/vgg16/export_qat.py new file mode 100644 index 0000000000..3de373a34f --- /dev/null +++ b/cpp/int8/training/vgg16/export_qat.py @@ -0,0 +1,103 @@ +import argparse +import os + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.data as data +import torchvision.transforms as transforms +import torchvision.datasets as datasets + +from vgg16 import vgg16 + +from pytorch_quantization import quant_modules +from pytorch_quantization import nn as quant_nn + + +def test(model, dataloader, crit): + global writer + global classes + total = 0 + correct = 0 + loss = 0.0 + class_probs = [] + class_preds = [] + + with torch.no_grad(): + for data, labels in dataloader: + data, labels = data.cuda(), labels.cuda(non_blocking=True) + out = model(data) + loss += crit(out, labels) + preds = torch.max(out, 1)[1] + class_probs.append([F.softmax(i, dim=0) for i in out]) + class_preds.append(preds) + total += labels.size(0) + correct += (preds == labels).sum().item() + + test_probs = torch.cat([torch.stack(batch) for batch in class_probs]) + test_preds = torch.cat(class_preds) + return loss / total, correct / total + + +PARSER = argparse.ArgumentParser(description="Export trained VGG") +PARSER.add_argument('ckpt', type=str, help="Path to saved checkpoint") +PARSER.add_argument('--enable_qat', action="store_true", help="Enable quantization aware training. This is recommended to perform on a pre-trained model.") + +args = PARSER.parse_args() + +quant_modules.initialize() +model = vgg16(num_classes=10, init_weights=False) +model = model.cuda() + +ckpt = torch.load(args.ckpt) +weights = ckpt["model_state_dict"] + +# if torch.cuda.device_count() > 1: +# from collections import OrderedDict +# new_state_dict = OrderedDict() +# for k, v in weights.items(): +# name = k[7:] # remove `module.` +# new_state_dict[name] = v +# weights = new_state_dict + +# model.load_state_dict(weights) +# model.eval() +jit_model = torch.jit.load('trained_vgg16_qat.jit.pt') +testing_dataset = datasets.CIFAR10(root='./data', + train=False, + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)) + ])) + +testing_dataloader = torch.utils.data.DataLoader(testing_dataset, batch_size=32, shuffle=False, num_workers=2) + +crit = torch.nn.CrossEntropyLoss() + +# +# quant_nn.TensorQuantizer.use_fb_fake_quant = True +# with torch.no_grad(): +# data = iter(testing_dataloader) +# images, _ = data.next() +# jit_model = torch.jit.trace(model, images.to("cuda")) +# # jit_model.eval() +# torch.jit.save(jit_model, "trained_vgg16_qat.jit.pt") +# +test_loss, test_acc = test(jit_model, testing_dataloader, crit) +print("[JIT] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc)) +# +# if args.enable_qat: +# quant_nn.TensorQuantizer.use_fb_fake_quant = True + +import trtorch +# trtorch.logging.set_reportable_log_level(trtorch.logging.Level.Debug) +compile_settings = { +"input_shapes": [[1, 3, 32, 32]], +"op_precision": torch.int8 # Run with FP16 +} +new_mod = torch.jit.load('trained_vgg16_qat.jit.pt') +trt_ts_module = trtorch.compile(new_mod, compile_settings) +testing_dataloader = torch.utils.data.DataLoader(testing_dataset, batch_size=1, shuffle=False, num_workers=2) +test_loss, test_acc = test(trt_ts_module, testing_dataloader, crit) +print("[JIT] Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc)) diff --git a/cpp/int8/training/vgg16/main.py b/cpp/int8/training/vgg16/main.py new file mode 100644 index 0000000000..627688cf9b --- /dev/null +++ b/cpp/int8/training/vgg16/main.py @@ -0,0 +1,224 @@ +import argparse +import os +import random +from datetime import datetime + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch.utils.data as data +import torchvision.transforms as transforms +import torchvision.datasets as datasets + +from torch.utils.tensorboard import SummaryWriter + +from vgg16 import vgg16 + +PARSER = argparse.ArgumentParser(description="VGG16 example to use with TRTorch PTQ") +PARSER.add_argument('--epochs', default=100, type=int, help="Number of total epochs to train") +PARSER.add_argument('--batch-size', default=128, type=int, help="Batch size to use when training") +PARSER.add_argument('--lr', default=0.1, type=float, help="Initial learning rate") +PARSER.add_argument('--drop-ratio', default=0., type=float, help="Dropout ratio") +PARSER.add_argument('--momentum', default=0.9, type=float, help="Momentum") +PARSER.add_argument('--weight-decay', default=5e-4, type=float, help="Weight decay") +PARSER.add_argument('--ckpt-dir', + default="/tmp/vgg16_ckpts", + type=str, + help="Path to save checkpoints (saved every 10 epochs)") +PARSER.add_argument('--start-from', + default=0, + type=int, + help="Epoch to resume from (requires a checkpoin in the providied checkpoi") +PARSER.add_argument('--seed', type=int, help='Seed value for rng') +PARSER.add_argument('--tensorboard', type=str, default='/tmp/vgg16_logs', help='Location for tensorboard info') + +args = PARSER.parse_args() +for arg in vars(args): + print(' {} {}'.format(arg, getattr(args, arg))) +state = {k: v for k, v in args._get_kwargs()} + +if args.seed is None: + args.seed = random.randint(1, 10000) +random.seed(args.seed) +torch.manual_seed(args.seed) +torch.cuda.manual_seed_all(args.seed) +print("RNG seed used: ", args.seed) + +now = datetime.now() + +timestamp = datetime.timestamp(now) + +writer = SummaryWriter(args.tensorboard + '/test_' + str(timestamp)) +classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') + + +def main(): + global state + global classes + global writer + if not os.path.isdir(args.ckpt_dir): + os.makedirs(args.ckpt_dir) + + training_dataset = datasets.CIFAR10(root='./data', + train=True, + download=True, + transform=transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ])) + training_dataloader = torch.utils.data.DataLoader(training_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=2) + + testing_dataset = datasets.CIFAR10(root='./data', + train=False, + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ])) + + testing_dataloader = torch.utils.data.DataLoader(testing_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=2) + + num_classes = len(classes) + + model = vgg16(num_classes=num_classes, init_weights=False) + model = model.cuda() + + data = iter(training_dataloader) + images, _ = data.next() + + writer.add_graph(model, images.cuda()) + writer.close() + + crit = nn.CrossEntropyLoss() + opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + + if torch.cuda.device_count() > 1: + model = nn.DataParallel(model) + + if args.start_from != 0: + ckpt_file = args.ckpt_dir + '/ckpt_epoch' + str(args.start_from) + '.pth' + print('Loading from checkpoint {}'.format(ckpt_file)) + assert (os.path.isfile(ckpt_file)) + ckpt = torch.load(ckpt_file) + model.load_state_dict(ckpt["model_state_dict"]) + opt.load_state_dict(ckpt["opt_state_dict"]) + state = ckpt["state"] + + for epoch in range(args.start_from, args.epochs): + adjust_lr(opt, epoch) + writer.add_scalar('Learning Rate', state["lr"], epoch) + writer.close() + print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, args.epochs, state['lr'])) + + train(model, training_dataloader, crit, opt, epoch) + test_loss, test_acc = test(model, testing_dataloader, crit, epoch) + + print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc)) + + if epoch % 10 == 9: + save_checkpoint( + { + 'epoch': epoch + 1, + 'model_state_dict': model.state_dict(), + 'acc': test_acc, + 'opt_state_dict': opt.state_dict(), + 'state': state + }, + ckpt_dir=args.ckpt_dir) + + +def train(model, dataloader, crit, opt, epoch): + global writer + model.train() + running_loss = 0.0 + for batch, (data, labels) in enumerate(dataloader): + data, labels = data.cuda(), labels.cuda(non_blocking=True) + opt.zero_grad() + out = model(data) + loss = crit(out, labels) + loss.backward() + opt.step() + + running_loss += loss.item() + if batch % 50 == 49: + writer.add_scalar('Training Loss', running_loss / 100, epoch * len(dataloader) + batch) + writer.close() + print("Batch: [%5d | %5d] loss: %.3f" % (batch + 1, len(dataloader), running_loss / 100)) + running_loss = 0.0 + + +def test(model, dataloader, crit, epoch): + global writer + global classes + total = 0 + correct = 0 + loss = 0.0 + class_probs = [] + class_preds = [] + model.eval() + with torch.no_grad(): + for data, labels in dataloader: + data, labels = data.cuda(), labels.cuda(non_blocking=True) + out = model(data) + loss += crit(out, labels) + preds = torch.max(out, 1)[1] + class_probs.append([F.softmax(i, dim=0) for i in out]) + class_preds.append(preds) + total += labels.size(0) + correct += (preds == labels).sum().item() + + writer.add_scalar('Testing Loss', loss / total, epoch) + writer.close() + + writer.add_scalar('Testing Accuracy', correct / total * 100, epoch) + writer.close() + + test_probs = torch.cat([torch.stack(batch) for batch in class_probs]) + test_preds = torch.cat(class_preds) + for i in range(len(classes)): + add_pr_curve_tensorboard(i, test_probs, test_preds, epoch) + #print(loss, total, correct, total) + return loss / total, correct / total + + +def save_checkpoint(state, ckpt_dir='checkpoint'): + print("Checkpoint {} saved".format(state['epoch'])) + filename = "ckpt_epoch" + str(state['epoch']) + ".pth" + filepath = os.path.join(ckpt_dir, filename) + torch.save(state, filepath) + + +def adjust_lr(optimizer, epoch): + global state + new_lr = state["lr"] * (0.5**(epoch // 40)) if state["lr"] > 1e-7 else state["lr"] + if new_lr != state["lr"]: + state["lr"] = new_lr + print("Updating learning rate: {}".format(state["lr"])) + for param_group in optimizer.param_groups: + param_group["lr"] = state["lr"] + + +def add_pr_curve_tensorboard(class_index, test_probs, test_preds, global_step=0): + global classes + ''' + Takes in a "class_index" from 0 to 9 and plots the corresponding + precision-recall curve + ''' + tensorboard_preds = test_preds == class_index + tensorboard_probs = test_probs[:, class_index] + + writer.add_pr_curve(classes[class_index], tensorboard_preds, tensorboard_probs, global_step=global_step) + writer.close() + + +if __name__ == "__main__": + main() diff --git a/cpp/int8/training/vgg16/requirements.txt b/cpp/int8/training/vgg16/requirements.txt new file mode 100644 index 0000000000..c6bebeaec2 --- /dev/null +++ b/cpp/int8/training/vgg16/requirements.txt @@ -0,0 +1,2 @@ +torch>=1.4.0 +tensorboard>=1.14.0 \ No newline at end of file diff --git a/cpp/int8/training/vgg16/train_qat.py b/cpp/int8/training/vgg16/train_qat.py new file mode 100644 index 0000000000..74c55ff45c --- /dev/null +++ b/cpp/int8/training/vgg16/train_qat.py @@ -0,0 +1,329 @@ +import argparse +import os +import random +from datetime import datetime + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +import torch.utils.data as data +import torchvision.transforms as transforms +import torchvision.datasets as datasets + +from torch.utils.tensorboard import SummaryWriter + +from pytorch_quantization import nn as quant_nn +from pytorch_quantization import quant_modules +from pytorch_quantization.tensor_quant import QuantDescriptor +from pytorch_quantization import calib +from tqdm import tqdm + +from vgg16 import vgg16 + +PARSER = argparse.ArgumentParser(description="VGG16 example to use with TRTorch PTQ") +PARSER.add_argument('--epochs', default=100, type=int, help="Number of total epochs to train") +PARSER.add_argument('--enable_qat', action="store_true", help="Enable quantization aware training. This is recommended to perform on a pre-trained model.") +PARSER.add_argument('--batch-size', default=128, type=int, help="Batch size to use when training") +PARSER.add_argument('--lr', default=0.1, type=float, help="Initial learning rate") +PARSER.add_argument('--drop-ratio', default=0., type=float, help="Dropout ratio") +PARSER.add_argument('--momentum', default=0.9, type=float, help="Momentum") +PARSER.add_argument('--weight-decay', default=5e-4, type=float, help="Weight decay") +PARSER.add_argument('--ckpt-dir', + default="/tmp/vgg16_ckpts", + type=str, + help="Path to save checkpoints (saved every 10 epochs)") +PARSER.add_argument('--start-from', + default=0, + type=int, + help="Epoch to resume from (requires a checkpoin in the providied checkpoi") +PARSER.add_argument('--seed', type=int, help='Seed value for rng') +PARSER.add_argument('--tensorboard', type=str, default='/tmp/vgg16_logs', help='Location for tensorboard info') + +args = PARSER.parse_args() +for arg in vars(args): + print(' {} {}'.format(arg, getattr(args, arg))) +state = {k: v for k, v in args._get_kwargs()} + +if args.seed is None: + args.seed = random.randint(1, 10000) +random.seed(args.seed) +torch.manual_seed(args.seed) +torch.cuda.manual_seed_all(args.seed) +print("RNG seed used: ", args.seed) + +now = datetime.now() + +timestamp = datetime.timestamp(now) + +writer = SummaryWriter(args.tensorboard + '/test_' + str(timestamp)) +classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck') + +def compute_amax(model, **kwargs): + # Load calib result + for name, module in model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer): + if module._calibrator is not None: + if isinstance(module._calibrator, calib.MaxCalibrator): + module.load_calib_amax() + else: + module.load_calib_amax(**kwargs) + print(F"{name:40}: {module}") + model.cuda() + +def collect_stats(model, data_loader, num_batches): + """Feed data to the network and collect statistics""" + # Enable calibrators + for name, module in model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer): + if module._calibrator is not None: + module.disable_quant() + module.enable_calib() + else: + module.disable() + + # Feed data to the network for collecting stats + for i, (image, _) in tqdm(enumerate(data_loader), total=num_batches): + model(image.cuda()) + if i >= num_batches: + break + + # Disable calibrators + for name, module in model.named_modules(): + if isinstance(module, quant_nn.TensorQuantizer): + if module._calibrator is not None: + module.enable_quant() + module.disable_calib() + else: + module.enable() + +def calibrate_model(model, model_name, data_loader, num_calib_batch, calibrator, hist_percentile, out_dir): + """ + Feed data to the network and calibrate. + Arguments: + model: classification model + model_name: name to use when creating state files + data_loader: calibration data set + num_calib_batch: amount of calibration passes to perform + calibrator: type of calibration to use (max/histogram) + hist_percentile: percentiles to be used for historgram calibration + out_dir: dir to save state files in + """ + + if num_calib_batch > 0: + print("Calibrating model") + with torch.no_grad(): + collect_stats(model, data_loader, num_calib_batch) + + if not calibrator == "histogram": + compute_amax(model, method="max") + calib_output = os.path.join( + out_dir, + F"{model_name}-max-{num_calib_batch*data_loader.batch_size}.pth") + torch.save(model.state_dict(), calib_output) + else: + for percentile in hist_percentile: + print(F"{percentile} percentile calibration") + compute_amax(model, method="percentile") + calib_output = os.path.join( + out_dir, + F"{model_name}-percentile-{percentile}-{num_calib_batch*data_loader.batch_size}.pth") + torch.save(model.state_dict(), calib_output) + + for method in ["mse", "entropy"]: + print(F"{method} calibration") + compute_amax(model, method=method) + calib_output = os.path.join( + out_dir, + F"{model_name}-{method}-{num_calib_batch*data_loader.batch_size}.pth") + torch.save(model.state_dict(), calib_output) + + +def main(): + + global state + global classes + global writer + if not os.path.isdir(args.ckpt_dir): + os.makedirs(args.ckpt_dir) + + training_dataset = datasets.CIFAR10(root='./data', + train=True, + download=True, + transform=transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ])) + training_dataloader = torch.utils.data.DataLoader(training_dataset, + batch_size=args.batch_size, + shuffle=True, + num_workers=2) + + testing_dataset = datasets.CIFAR10(root='./data', + train=False, + download=True, + transform=transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)), + ])) + + testing_dataloader = torch.utils.data.DataLoader(testing_dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=2) + + num_classes = len(classes) + + quant_modules.initialize() + + model = vgg16(num_classes=num_classes, init_weights=False) + model = model.cuda() + + crit = nn.CrossEntropyLoss() + opt = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay) + + if args.start_from != 0: + ckpt_file = args.ckpt_dir + '/ckpt_epoch' + str(args.start_from) + '.pth' + print('Loading from checkpoint {}'.format(ckpt_file)) + assert (os.path.isfile(ckpt_file)) + ckpt = torch.load(ckpt_file) + modified_state_dict={} + + for key, val in ckpt["model_state_dict"].items(): + modified_state_dict[key] = val + + model.load_state_dict(modified_state_dict) + opt.load_state_dict(ckpt["opt_state_dict"]) + state = ckpt["state"] + + data = iter(training_dataloader) + images, _ = data.next() + + writer.add_graph(model, images.cuda()) + writer.close() + + # ## Calibrate the model + # with torch.no_grad(): + # calibrate_model( + # model=model, + # model_name="vgg16", + # data_loader=training_dataloader, + # num_calib_batch=32, + # calibrator="max", + # hist_percentile=[99.9, 99.99, 99.999, 99.9999], + # out_dir="./") + + if torch.cuda.device_count() > 1: + model = nn.DataParallel(model) + + for epoch in range(args.start_from, args.epochs): + adjust_lr(opt, epoch) + writer.add_scalar('Learning Rate', state["lr"], epoch) + writer.close() + print('Epoch: [%5d / %5d] LR: %f' % (epoch + 1, args.epochs, state['lr'])) + + train(model, training_dataloader, crit, opt, epoch) + test_loss, test_acc = test(model, testing_dataloader, crit, epoch) + + print("Test Loss: {:.5f} Test Acc: {:.2f}%".format(test_loss, 100 * test_acc)) + + if epoch % 10 == 9 or epoch == args.epochs-1: + save_checkpoint( + { + 'epoch': epoch + 1, + 'model_state_dict': model.state_dict(), + 'acc': test_acc, + 'opt_state_dict': opt.state_dict(), + 'state': state + }, + ckpt_dir=args.ckpt_dir) + + +def train(model, dataloader, crit, opt, epoch): + global writer + model.train() + running_loss = 0.0 + for batch, (data, labels) in enumerate(dataloader): + data, labels = data.cuda(), labels.cuda(non_blocking=True) + opt.zero_grad() + out = model(data) + loss = crit(out, labels) + loss.backward() + opt.step() + + running_loss += loss.item() + if batch % 50 == 49: + writer.add_scalar('Training Loss', running_loss / 100, epoch * len(dataloader) + batch) + writer.close() + print("Batch: [%5d | %5d] loss: %.3f" % (batch + 1, len(dataloader), running_loss / 100)) + running_loss = 0.0 + + +def test(model, dataloader, crit, epoch): + global writer + global classes + total = 0 + correct = 0 + loss = 0.0 + class_probs = [] + class_preds = [] + model.eval() + with torch.no_grad(): + for data, labels in dataloader: + data, labels = data.cuda(), labels.cuda(non_blocking=True) + out = model(data) + loss += crit(out, labels) + preds = torch.max(out, 1)[1] + class_probs.append([F.softmax(i, dim=0) for i in out]) + class_preds.append(preds) + total += labels.size(0) + correct += (preds == labels).sum().item() + + writer.add_scalar('Testing Loss', loss / total, epoch) + writer.close() + + writer.add_scalar('Testing Accuracy', correct / total * 100, epoch) + writer.close() + + test_probs = torch.cat([torch.stack(batch) for batch in class_probs]) + test_preds = torch.cat(class_preds) + for i in range(len(classes)): + add_pr_curve_tensorboard(i, test_probs, test_preds, epoch) + print(loss, total, correct, total) + return loss / total, correct / total + + +def save_checkpoint(state, ckpt_dir='checkpoint'): + print("Checkpoint {} saved".format(state['epoch'])) + filename = "ckpt_epoch" + str(state['epoch']) + ".pth" + filepath = os.path.join(ckpt_dir, filename) + torch.save(state, filepath) + + +def adjust_lr(optimizer, epoch): + global state + new_lr = state["lr"] * (0.5**(epoch // 40)) if state["lr"] > 1e-7 else state["lr"] + if new_lr != state["lr"]: + state["lr"] = new_lr + print("Updating learning rate: {}".format(state["lr"])) + for param_group in optimizer.param_groups: + param_group["lr"] = state["lr"] + + +def add_pr_curve_tensorboard(class_index, test_probs, test_preds, global_step=0): + global classes + ''' + Takes in a "class_index" from 0 to 9 and plots the corresponding + precision-recall curve + ''' + tensorboard_preds = test_preds == class_index + tensorboard_probs = test_probs[:, class_index] + + writer.add_pr_curve(classes[class_index], tensorboard_preds, tensorboard_probs, global_step=global_step) + writer.close() + + +if __name__ == "__main__": + main() diff --git a/cpp/int8/training/vgg16/vgg16.py b/cpp/int8/training/vgg16/vgg16.py new file mode 100644 index 0000000000..6b3ddcd898 --- /dev/null +++ b/cpp/int8/training/vgg16/vgg16.py @@ -0,0 +1,56 @@ +''' +# Reference +- [Very Deep Convolutional Networks for Large-Scale Image Recognition]( + https://arxiv.org/abs/1409.1556) (ICLR 2015) +''' +import torch +import torch.nn as nn +import torch.nn.functional as F +from functools import reduce + + +class VGG(nn.Module): + + def __init__(self, layer_spec, num_classes=1000, init_weights=False): + super(VGG, self).__init__() + + layers = [] + in_channels = 3 + for l in layer_spec: + if l == 'pool': + layers.append(nn.MaxPool2d(kernel_size=2, stride=2)) + else: + layers += [nn.Conv2d(in_channels, l, kernel_size=3, padding=1), nn.BatchNorm2d(l), nn.ReLU()] + in_channels = l + + self.features = nn.Sequential(*layers) + self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + self.classifier = nn.Sequential(nn.Linear(512 * 1 * 1, 4096), nn.ReLU(), nn.Dropout(), nn.Linear(4096, 4096), + nn.ReLU(), nn.Dropout(), nn.Linear(4096, num_classes)) + if init_weights: + self._initialize_weights() + + def _initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.BatchNorm2d): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Linear): + nn.init.normal_(m.weight, 0, 0.01) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.features(x) + x = self.avgpool(x) + x = torch.flatten(x, 1) + x = self.classifier(x) + return x + + +def vgg16(num_classes=1000, init_weights=False): + vgg16_cfg = [64, 64, 'pool', 128, 128, 'pool', 256, 256, 256, 'pool', 512, 512, 512, 'pool', 512, 512, 512, 'pool'] + return VGG(vgg16_cfg, num_classes, init_weights)