-
Notifications
You must be signed in to change notification settings - Fork 360
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(//cpp/int8/qat): QAT application release
Signed-off-by: Dheeraj Peri <[email protected]>
- Loading branch information
Showing
20 changed files
with
1,819 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
package(default_visibility = ["//visibility:public"]) | ||
|
||
# Library exposing the benchmarking helpers declared in benchmark.h. | ||
cc_library( | ||
name = "benchmark", | ||
# timer.h is listed under srcs (not hdrs), so it is a private header of this target. | ||
srcs = [ | ||
"benchmark.cpp", | ||
"timer.h", | ||
], | ||
hdrs = [ | ||
"benchmark.h", | ||
], | ||
deps = [ | ||
"//cpp/api:trtorch", | ||
"@libtorch", | ||
"@libtorch//:caffe2", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,73 @@ | ||
#include "ATen/Context.h" | ||
#include "c10/cuda/CUDACachingAllocator.h" | ||
#include "cuda_runtime_api.h" | ||
#include "torch/script.h" | ||
#include "torch/torch.h" | ||
#include "trtorch/trtorch.h" | ||
|
||
#include "timer.h" | ||
|
||
#define NUM_WARMUP_RUNS 20 | ||
#define NUM_RUNS 100 | ||
|
||
// Benchmarking code | ||
// Prints average latency, average throughput (FPS) and their standard
// deviations for a set of timed runs.
//
// type       - label printed in brackets at the start of the report
// runtimes   - per-run latencies in milliseconds
// batch_size - samples processed per run; used to convert latency into FPS
//
// An empty `runtimes` produces a short notice instead of dividing by zero
// (which previously printed NaN for every statistic).
void print_avg_std_dev(std::string type, std::vector<float>& runtimes, uint64_t batch_size) {
  if (runtimes.empty()) {
    std::cout << "[" << type << "]: batch_size: " << batch_size << "\n No runtimes recorded" << std::endl;
    return;
  }

  const auto num_runs = runtimes.size();
  // Converts a latency in milliseconds into frames per second for this batch size.
  const auto to_fps = [batch_size](float latency_ms) { return (1000.f / latency_ms) * batch_size; };

  float avg_runtime = std::accumulate(runtimes.begin(), runtimes.end(), 0.0) / num_runs;
  float fps = to_fps(avg_runtime);
  std::cout << "[" << type << "]: batch_size: " << batch_size << "\n Average latency: " << avg_runtime
            << " ms\n Average FPS: " << fps << " fps" << std::endl;

  // Accumulate both squared-deviation sums in one pass; sums are kept in
  // double, matching the precision of the accumulate above.
  double rt_sq_sum = 0.0;
  double fps_sq_sum = 0.0;
  for (const float rt : runtimes) {
    const float rt_diff = rt - avg_runtime;
    const float fps_diff = to_fps(rt) - fps;
    rt_sq_sum += rt_diff * rt_diff;
    fps_sq_sum += fps_diff * fps_diff;
  }

  // Population standard deviation (divides by N, not N - 1).
  float rt_std_dev = std::sqrt(static_cast<float>(rt_sq_sum) / num_runs);
  float fps_std_dev = std::sqrt(static_cast<float>(fps_sq_sum) / num_runs);
  std::cout << " Latency Standard Deviation: " << rt_std_dev << "\n FPS Standard Deviation: " << fps_std_dev
            << "\n(excluding initial warmup runs)" << std::endl;
}
|
||
// Times NUM_RUNS forward passes of `mod` on random CUDA inputs of the given
// shape, after NUM_WARMUP_RUNS untimed passes. Returns the latency of each
// timed pass in milliseconds.
std::vector<float> benchmark_module(torch::jit::script::Module& mod, std::vector<int64_t> shape) {
  std::vector<float> latencies_ms;
  auto stopwatch = timers::PreciseCPUTimer();

  // Warm-up phase: same work as the timed phase, but nothing is recorded.
  for (int warmup = 0; warmup < NUM_WARMUP_RUNS; ++warmup) {
    std::vector<torch::jit::IValue> args;
    auto sample = at::rand(shape, {at::kCUDA});
#ifdef HALF
    sample = sample.to(torch::kHalf);
#endif
    args.push_back(sample.clone());

    cudaDeviceSynchronize();
    mod.forward(args);
    cudaDeviceSynchronize();
  }

  // Timed phase.
  for (int run = 0; run < NUM_RUNS; ++run) {
    std::vector<torch::jit::IValue> args;
    auto sample = at::rand(shape, {at::kCUDA});
#ifdef HALF
    sample = sample.to(torch::kHalf);
#endif
    args.push_back(sample.clone());
    // Synchronize before starting the timer so input preparation is not measured.
    cudaDeviceSynchronize();

    stopwatch.start();
    mod.forward(args);
    // Synchronize again so the timer is stopped only after the device is done.
    cudaDeviceSynchronize();
    stopwatch.stop();

    const auto elapsed_ms = stopwatch.milliseconds();
    stopwatch.reset();
    latencies_ms.push_back(elapsed_ms);

    c10::cuda::CUDACachingAllocator::emptyCache();
  }

  return latencies_ms;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
#pragma once | ||
|
||
// Prints average latency / FPS and their standard deviations for `runtimes` | ||
// (per-run latencies in milliseconds); `type` is the label shown in brackets. | ||
void print_avg_std_dev(std::string type, std::vector<float>& runtimes, uint64_t batch_size); | ||
// Runs `mod` on random CUDA inputs of the given shape and returns the | ||
// latency of each timed run in milliseconds (warm-up runs are not included). | ||
std::vector<float> benchmark_module(torch::jit::script::Module& mod, std::vector<int64_t> shape); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
#pragma once | ||
#include <chrono> | ||
|
||
namespace timers {
// Base class for timers. Stores the accumulated elapsed time in milliseconds
// and exposes it in several units; start()/stop() are no-ops here and are
// meant to be overridden by concrete timers.
class TimerBase {
 public:
  TimerBase() = default;
  TimerBase(const TimerBase&) = default;
  TimerBase& operator=(const TimerBase&) = default;
  // Virtual so that destroying a derived timer through a TimerBase pointer
  // is well defined.
  virtual ~TimerBase() = default;

  virtual void start() {}
  virtual void stop() {}

  // Accumulated time in microseconds.
  float microseconds() const noexcept {
    return mMs * 1000.f;
  }
  // Accumulated time in milliseconds.
  float milliseconds() const noexcept {
    return mMs;
  }
  // Accumulated time in seconds.
  float seconds() const noexcept {
    return mMs / 1000.f;
  }
  // Clears the accumulated time.
  void reset() noexcept {
    mMs = 0.f;
  }

 protected:
  float mMs{0.0f};
};

// Timer that measures wall time with the given std::chrono clock. Each
// start()/stop() pair adds its duration to the accumulated total, so call
// reset() between measurements that should not be summed.
template <typename Clock>
class CPUTimer : public TimerBase {
 public:
  using clock_type = Clock;

  void start() override {
    mStart = Clock::now();
  }
  void stop() override {
    mStop = Clock::now();
    mMs += std::chrono::duration<float, std::milli>{mStop - mStart}.count();
  }

 private:
  std::chrono::time_point<Clock> mStart, mStop;
}; // class CPUTimer

using PreciseCPUTimer = CPUTimer<std::chrono::high_resolution_clock>;
} // namespace timers
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
package(default_visibility = ["//visibility:public"]) | ||
|
||
# Library providing the CIFAR10 dataset loader (cifar10.cpp / cifar10.h). | ||
cc_library( | ||
name = "cifar10", | ||
srcs = [ | ||
"cifar10.cpp", | ||
], | ||
hdrs = [ | ||
"cifar10.h", | ||
], | ||
deps = [ | ||
"@libtorch", | ||
], | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
// #include "cpp/int8/ptq/datasets/cifar10.h" | ||
#include "cifar10.h" | ||
#include "torch/data/example.h" | ||
#include "torch/torch.h" | ||
#include "torch/types.h" | ||
|
||
#include <cstddef> | ||
#include <fstream> | ||
#include <iostream> | ||
#include <memory> | ||
#include <sstream> | ||
#include <string> | ||
#include <utility> | ||
#include <vector> | ||
|
||
namespace datasets { | ||
namespace { | ||
// File naming inside the extracted CIFAR-10 binary archive.
constexpr const char* kTrainFilenamePrefix = "data_batch_";
constexpr const char* kTestFilename = "test_batch.bin";
constexpr uint32_t kNumTrainFiles = 5;

// Image geometry.
constexpr size_t kImageChannels = 3;
constexpr size_t kImageDim = 32;

// On-disk record layout: one label byte followed by the image bytes.
constexpr size_t kLabelSize = 1; // bytes
constexpr size_t kImageSize = kImageChannels * kImageDim * kImageDim; // bytes (3072)

// Number of records read from each batch file.
constexpr size_t kBatchSize = 10000;
|
||
std::pair<torch::Tensor, torch::Tensor> read_batch(const std::string& path) { | ||
std::ifstream batch; | ||
batch.open(path, std::ios::in | std::ios::binary | std::ios::ate); | ||
|
||
auto file_size = batch.tellg(); | ||
std::unique_ptr<char[]> buf(new char[file_size]); | ||
|
||
batch.seekg(0, std::ios::beg); | ||
batch.read(buf.get(), file_size); | ||
batch.close(); | ||
|
||
std::vector<uint8_t> labels; | ||
std::vector<torch::Tensor> images; | ||
labels.reserve(kBatchSize); | ||
images.reserve(kBatchSize); | ||
|
||
for (size_t i = 0; i < kBatchSize; i++) { | ||
uint8_t label = buf[i * (kImageSize + kLabelSize)]; | ||
std::vector<uint8_t> image; | ||
image.reserve(kImageSize); | ||
std::copy( | ||
&buf[i * (kImageSize + kLabelSize) + 1], | ||
&buf[i * (kImageSize + kLabelSize) + kImageSize], | ||
std::back_inserter(image)); | ||
labels.push_back(label); | ||
auto image_tensor = | ||
torch::from_blob(image.data(), {kImageChannels, kImageDim, kImageDim}, torch::TensorOptions().dtype(torch::kU8)) | ||
.to(torch::kF32); | ||
images.push_back(image_tensor); | ||
} | ||
|
||
auto labels_tensor = | ||
torch::from_blob(labels.data(), {kBatchSize}, torch::TensorOptions().dtype(torch::kU8)).to(torch::kF32); | ||
assert(labels_tensor.size(0) == kBatchSize); | ||
|
||
auto images_tensor = torch::stack(images); | ||
assert(images_tensor.size(0) == kBatchSize); | ||
|
||
return std::make_pair(images_tensor, labels_tensor); | ||
} | ||
|
||
std::pair<torch::Tensor, torch::Tensor> read_train_data(const std::string& root) { | ||
std::vector<torch::Tensor> images, targets; | ||
for (uint32_t i = 1; i <= 5; i++) { | ||
std::stringstream ss; | ||
ss << root << '/' << kTrainFilenamePrefix << i << ".bin"; | ||
auto batch = read_batch(ss.str()); | ||
images.push_back(batch.first); | ||
targets.push_back(batch.second); | ||
} | ||
|
||
torch::Tensor image_tensor = | ||
std::accumulate(++images.begin(), images.end(), *images.begin(), [&](torch::Tensor a, torch::Tensor b) { | ||
return torch::cat({a, b}, 0); | ||
}); | ||
torch::Tensor target_tensor = | ||
std::accumulate(++targets.begin(), targets.end(), *targets.begin(), [&](torch::Tensor a, torch::Tensor b) { | ||
return torch::cat({a, b}, 0); | ||
}); | ||
|
||
return std::make_pair(image_tensor, target_tensor); | ||
} | ||
|
||
std::pair<torch::Tensor, torch::Tensor> read_test_data(const std::string& root) { | ||
std::stringstream ss; | ||
ss << root << '/' << kTestFilename; | ||
return read_batch(ss.str()); | ||
} | ||
} // namespace | ||
|
||
// Loads either the training or the test split from `root`, depending on
// `mode`, and stores the resulting image and target tensors.
CIFAR10::CIFAR10(const std::string& root, Mode mode) : mode_(mode) {
  auto data = (mode_ == Mode::kTrain) ? read_train_data(root) : read_test_data(root);

  images_ = std::move(data.first);
  targets_ = std::move(data.second);
  // Every image must have a matching target (the previous check compared
  // images_ with itself and could never fail).
  assert(images_.size(0) == targets_.size(0));
}
|
||
// Returns the (image, target) pair stored at position `index`.
torch::data::Example<> CIFAR10::get(size_t index) {
  torch::Tensor image = images_[index];
  torch::Tensor target = targets_[index];
  return {image, target};
}
|
||
// Number of examples, i.e. the extent of the image tensor along dimension 0.
c10::optional<size_t> CIFAR10::size() const {
  const auto num_examples = images_.size(0);
  return num_examples;
}
|
||
bool CIFAR10::is_train() const noexcept { | ||
return mode_ == Mode::kTrain; | ||
} | ||
|
||
// Read-only access to the tensor holding all images.
const torch::Tensor& CIFAR10::images() const {
  return this->images_;
}
|
||
// Read-only access to the tensor holding all targets.
const torch::Tensor& CIFAR10::targets() const {
  return this->targets_;
}
|
||
// Keeps only the first `new_size` examples (images and targets are sliced
// identically along dimension 0) and hands the dataset back as an rvalue so
// the call can be chained into a consumer that takes ownership.
CIFAR10&& CIFAR10::use_subset(int64_t new_size) {
  assert(new_size <= images_.sizes()[0]);
  images_ = images_.slice(/*dim=*/0, /*start=*/0, /*end=*/new_size);
  targets_ = targets_.slice(/*dim=*/0, /*start=*/0, /*end=*/new_size);
  return static_cast<CIFAR10&&>(*this);
}
|
||
} // namespace datasets |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
#pragma once | ||
|
||
#include "torch/data/datasets/base.h" | ||
#include "torch/data/example.h" | ||
#include "torch/types.h" | ||
|
||
#include <cstddef> | ||
#include <string> | ||
|
||
namespace datasets { | ||
// The CIFAR10 dataset, exposed through the torch::data dataset interface. | ||
// Holds all images and all targets as two tensors. | ||
class CIFAR10 : public torch::data::datasets::Dataset<CIFAR10> { | ||
public: | ||
// Which split of the dataset is loaded: training or test. | ||
enum class Mode { kTrain, kTest }; | ||
|
||
// Loads CIFAR10 from the extracted (un-tarred) binary distribution, | ||
// available at https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz | ||
// `root` should be the directory that contains the contents of the tarball. | ||
// `mode` selects the training (default) or test split. | ||
explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain); | ||
|
||
// Returns the (image, target) pair at `index` in the dataset. | ||
torch::data::Example<> get(size_t index) override; | ||
|
||
// The number of examples in the dataset. | ||
c10::optional<size_t> size() const override; | ||
|
||
// True if the dataset was loaded in Mode::kTrain. | ||
bool is_train() const noexcept; | ||
|
||
// Returns all images stacked into a single tensor. | ||
const torch::Tensor& images() const; | ||
|
||
// Returns all targets stacked into a single tensor. | ||
const torch::Tensor& targets() const; | ||
|
||
// Trims the dataset in place to its first `new_size` pairs and returns | ||
// an rvalue reference to this object. | ||
CIFAR10&& use_subset(int64_t new_size); | ||
|
||
private: | ||
// Split this dataset was constructed with. | ||
Mode mode_; | ||
// All images and all targets; index i along dimension 0 of each forms one example. | ||
torch::Tensor images_, targets_; | ||
}; | ||
} // namespace datasets |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
package(default_visibility = ["//visibility:public"]) | ||
|
||
# Executable built from main.cpp; links TRTorch, the benchmark helpers, | ||
# the CIFAR10 dataset library, libtorch and TensorRT's nvinfer. | ||
cc_binary( | ||
name = "ptq", | ||
srcs = [ | ||
"main.cpp", | ||
], | ||
# Compile and link with pthread support. | ||
copts = [ | ||
"-pthread", | ||
], | ||
linkopts = [ | ||
"-lpthread", | ||
], | ||
deps = [ | ||
"//cpp/api:trtorch", | ||
"//cpp/int8/benchmark", | ||
"//cpp/int8/datasets:cifar10", | ||
"@libtorch", | ||
"@libtorch//:caffe2", | ||
"@tensorrt//:nvinfer", | ||
], | ||
) |
Oops, something went wrong.