Skip to content

Commit

Permalink
feat(//cpp/int8/qat): QAT application release
Browse files Browse the repository at this point in the history
Signed-off-by: Dheeraj Peri <[email protected]>
  • Loading branch information
peri044 committed Jul 30, 2021
1 parent 004bf53 commit d8f5d29
Show file tree
Hide file tree
Showing 20 changed files with 1,819 additions and 0 deletions.
17 changes: 17 additions & 0 deletions cpp/int8/benchmark/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package(default_visibility = ["//visibility:public"])

cc_library(
name = "benchmark",
srcs = [
"benchmark.cpp",
"timer.h",
],
hdrs = [
"benchmark.h",
],
deps = [
"//cpp/api:trtorch",
"@libtorch",
"@libtorch//:caffe2",
],
)
73 changes: 73 additions & 0 deletions cpp/int8/benchmark/benchmark.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#include "ATen/Context.h"
#include "c10/cuda/CUDACachingAllocator.h"
#include "cuda_runtime_api.h"
#include "torch/script.h"
#include "torch/torch.h"
#include "trtorch/trtorch.h"

#include "timer.h"

#define NUM_WARMUP_RUNS 20
#define NUM_RUNS 100

// Benchmaking code
void print_avg_std_dev(std::string type, std::vector<float>& runtimes, uint64_t batch_size) {
float avg_runtime = std::accumulate(runtimes.begin(), runtimes.end(), 0.0) / runtimes.size();
float fps = (1000.f / avg_runtime) * batch_size;
std::cout << "[" << type << "]: batch_size: " << batch_size << "\n Average latency: " << avg_runtime
<< " ms\n Average FPS: " << fps << " fps" << std::endl;

std::vector<float> rt_diff(runtimes.size());
std::transform(runtimes.begin(), runtimes.end(), rt_diff.begin(), [avg_runtime](float x) { return x - avg_runtime; });
float rt_sq_sum = std::inner_product(rt_diff.begin(), rt_diff.end(), rt_diff.begin(), 0.0);
float rt_std_dev = std::sqrt(rt_sq_sum / runtimes.size());

std::vector<float> fps_diff(runtimes.size());
std::transform(runtimes.begin(), runtimes.end(), fps_diff.begin(), [fps, batch_size](float x) {
return ((1000.f / x) * batch_size) - fps;
});
float fps_sq_sum = std::inner_product(fps_diff.begin(), fps_diff.end(), fps_diff.begin(), 0.0);
float fps_std_dev = std::sqrt(fps_sq_sum / runtimes.size());
std::cout << " Latency Standard Deviation: " << rt_std_dev << "\n FPS Standard Deviation: " << fps_std_dev
<< "\n(excluding initial warmup runs)" << std::endl;
}

std::vector<float> benchmark_module(torch::jit::script::Module& mod, std::vector<int64_t> shape) {
auto execution_timer = timers::PreciseCPUTimer();
std::vector<float> execution_runtimes;

for (uint64_t i = 0; i < NUM_WARMUP_RUNS; i++) {
std::vector<torch::jit::IValue> inputs_ivalues;
auto in = at::rand(shape, {at::kCUDA});
#ifdef HALF
in = in.to(torch::kHalf);
#endif
inputs_ivalues.push_back(in.clone());

cudaDeviceSynchronize();
mod.forward(inputs_ivalues);
cudaDeviceSynchronize();
}

for (uint64_t i = 0; i < NUM_RUNS; i++) {
std::vector<torch::jit::IValue> inputs_ivalues;
auto in = at::rand(shape, {at::kCUDA});
#ifdef HALF
in = in.to(torch::kHalf);
#endif
inputs_ivalues.push_back(in.clone());
cudaDeviceSynchronize();

execution_timer.start();
mod.forward(inputs_ivalues);
cudaDeviceSynchronize();
execution_timer.stop();

auto time = execution_timer.milliseconds();
execution_timer.reset();
execution_runtimes.push_back(time);

c10::cuda::CUDACachingAllocator::emptyCache();
}
return execution_runtimes;
}
4 changes: 4 additions & 0 deletions cpp/int8/benchmark/benchmark.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#pragma once

void print_avg_std_dev(std::string type, std::vector<float>& runtimes, uint64_t batch_size);
std::vector<float> benchmark_module(torch::jit::script::Module& mod, std::vector<int64_t> shape);
44 changes: 44 additions & 0 deletions cpp/int8/benchmark/timer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#pragma once
#include <chrono>

namespace timers {
class TimerBase {
public:
virtual void start() {}
virtual void stop() {}
float microseconds() const noexcept {
return mMs * 1000.f;
}
float milliseconds() const noexcept {
return mMs;
}
float seconds() const noexcept {
return mMs / 1000.f;
}
void reset() noexcept {
mMs = 0.f;
}

protected:
float mMs{0.0f};
};

template <typename Clock>
class CPUTimer : public TimerBase {
public:
using clock_type = Clock;

void start() {
mStart = Clock::now();
}
void stop() {
mStop = Clock::now();
mMs += std::chrono::duration<float, std::milli>{mStop - mStart}.count();
}

private:
std::chrono::time_point<Clock> mStart, mStop;
}; // class CPUTimer

using PreciseCPUTimer = CPUTimer<std::chrono::high_resolution_clock>;
} // namespace timers
14 changes: 14 additions & 0 deletions cpp/int8/datasets/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
package(default_visibility = ["//visibility:public"])

cc_library(
name = "cifar10",
srcs = [
"cifar10.cpp",
],
hdrs = [
"cifar10.h",
],
deps = [
"@libtorch",
],
)
137 changes: 137 additions & 0 deletions cpp/int8/datasets/cifar10.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
// #include "cpp/int8/ptq/datasets/cifar10.h"
#include "cifar10.h"
#include "torch/data/example.h"
#include "torch/torch.h"
#include "torch/types.h"

#include <cstddef>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

namespace datasets {
namespace {
constexpr const char* kTrainFilenamePrefix = "data_batch_";
constexpr const uint32_t kNumTrainFiles = 5;
constexpr const char* kTestFilename = "test_batch.bin";
constexpr const size_t kLabelSize = 1; // B
constexpr const size_t kImageSize = 3072; // B
constexpr const size_t kImageDim = 32;
constexpr const size_t kImageChannels = 3;
constexpr const size_t kBatchSize = 10000;

std::pair<torch::Tensor, torch::Tensor> read_batch(const std::string& path) {
std::ifstream batch;
batch.open(path, std::ios::in | std::ios::binary | std::ios::ate);

auto file_size = batch.tellg();
std::unique_ptr<char[]> buf(new char[file_size]);

batch.seekg(0, std::ios::beg);
batch.read(buf.get(), file_size);
batch.close();

std::vector<uint8_t> labels;
std::vector<torch::Tensor> images;
labels.reserve(kBatchSize);
images.reserve(kBatchSize);

for (size_t i = 0; i < kBatchSize; i++) {
uint8_t label = buf[i * (kImageSize + kLabelSize)];
std::vector<uint8_t> image;
image.reserve(kImageSize);
std::copy(
&buf[i * (kImageSize + kLabelSize) + 1],
&buf[i * (kImageSize + kLabelSize) + kImageSize],
std::back_inserter(image));
labels.push_back(label);
auto image_tensor =
torch::from_blob(image.data(), {kImageChannels, kImageDim, kImageDim}, torch::TensorOptions().dtype(torch::kU8))
.to(torch::kF32);
images.push_back(image_tensor);
}

auto labels_tensor =
torch::from_blob(labels.data(), {kBatchSize}, torch::TensorOptions().dtype(torch::kU8)).to(torch::kF32);
assert(labels_tensor.size(0) == kBatchSize);

auto images_tensor = torch::stack(images);
assert(images_tensor.size(0) == kBatchSize);

return std::make_pair(images_tensor, labels_tensor);
}

std::pair<torch::Tensor, torch::Tensor> read_train_data(const std::string& root) {
std::vector<torch::Tensor> images, targets;
for (uint32_t i = 1; i <= 5; i++) {
std::stringstream ss;
ss << root << '/' << kTrainFilenamePrefix << i << ".bin";
auto batch = read_batch(ss.str());
images.push_back(batch.first);
targets.push_back(batch.second);
}

torch::Tensor image_tensor =
std::accumulate(++images.begin(), images.end(), *images.begin(), [&](torch::Tensor a, torch::Tensor b) {
return torch::cat({a, b}, 0);
});
torch::Tensor target_tensor =
std::accumulate(++targets.begin(), targets.end(), *targets.begin(), [&](torch::Tensor a, torch::Tensor b) {
return torch::cat({a, b}, 0);
});

return std::make_pair(image_tensor, target_tensor);
}

std::pair<torch::Tensor, torch::Tensor> read_test_data(const std::string& root) {
std::stringstream ss;
ss << root << '/' << kTestFilename;
return read_batch(ss.str());
}
} // namespace

CIFAR10::CIFAR10(const std::string& root, Mode mode) : mode_(mode) {
std::pair<torch::Tensor, torch::Tensor> data;
if (mode_ == Mode::kTrain) {
data = read_train_data(root);
} else {
data = read_test_data(root);
}

images_ = std::move(data.first);
targets_ = std::move(data.second);
assert(images_.sizes()[0] == images_.sizes()[0]);
}

torch::data::Example<> CIFAR10::get(size_t index) {
return {images_[index], targets_[index]};
}

c10::optional<size_t> CIFAR10::size() const {
return images_.size(0);
}

bool CIFAR10::is_train() const noexcept {
return mode_ == Mode::kTrain;
}

const torch::Tensor& CIFAR10::images() const {
return images_;
}

const torch::Tensor& CIFAR10::targets() const {
return targets_;
}

CIFAR10&& CIFAR10::use_subset(int64_t new_size) {
assert(new_size <= images_.sizes()[0]);
images_ = images_.slice(0, 0, new_size);
targets_ = targets_.slice(0, 0, new_size);
return std::move(*this);
}

} // namespace datasets
45 changes: 45 additions & 0 deletions cpp/int8/datasets/cifar10.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
#pragma once

#include "torch/data/datasets/base.h"
#include "torch/data/example.h"
#include "torch/types.h"

#include <cstddef>
#include <string>

namespace datasets {
// The CIFAR10 Dataset
class CIFAR10 : public torch::data::datasets::Dataset<CIFAR10> {
public:
// The mode in which the dataset is loaded
enum class Mode { kTrain, kTest };

// Loads CIFAR10 from un-tarred file
// Dataset can be found
// https://www.cs.toronto.edu/~kriz/cifar-10-binary.tar.gz Root path should be
// the directory that contains the content of tarball
explicit CIFAR10(const std::string& root, Mode mode = Mode::kTrain);

// Returns the pair at index in the dataset
torch::data::Example<> get(size_t index) override;

// The size of the dataset
c10::optional<size_t> size() const override;

// The mode the dataset is in
bool is_train() const noexcept;

// Returns all images stacked into a single tensor
const torch::Tensor& images() const;

// Returns all targets stacked into a single tensor
const torch::Tensor& targets() const;

// Trims the dataset to the first n pairs
CIFAR10&& use_subset(int64_t new_size);

private:
Mode mode_;
torch::Tensor images_, targets_;
};
} // namespace datasets
22 changes: 22 additions & 0 deletions cpp/int8/ptq/BUILD
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
package(default_visibility = ["//visibility:public"])

cc_binary(
name = "ptq",
srcs = [
"main.cpp",
],
copts = [
"-pthread",
],
linkopts = [
"-lpthread",
],
deps = [
"//cpp/api:trtorch",
"//cpp/int8/benchmark",
"//cpp/int8/datasets:cifar10",
"@libtorch",
"@libtorch//:caffe2",
"@tensorrt//:nvinfer",
],
)
Loading

0 comments on commit d8f5d29

Please sign in to comment.