diff --git a/cpp/api/BUILD b/cpp/api/BUILD
index fc0bb75408..d396d1690a 100644
--- a/cpp/api/BUILD
+++ b/cpp/api/BUILD
@@ -12,6 +12,7 @@ cc_library(
         "src/extra_info.cpp",
         "src/logging.cpp",
         "src/trtorch.cpp",
+        "src/ptq.cpp"
     ],
     deps = [
         "//core",
diff --git a/cpp/api/include/trtorch/ptq.h b/cpp/api/include/trtorch/ptq.h
index 6c5b09f9b0..afae26a85c 100644
--- a/cpp/api/include/trtorch/ptq.h
+++ b/cpp/api/include/trtorch/ptq.h
@@ -6,6 +6,8 @@
 #include
 #include
 
+#include "trtorch/logging.h"
+
 #ifndef DOXYGEN_SHOULD_SKIP_THIS
 namespace nvinfer1 {
 class IInt8Calibrator;
@@ -13,9 +15,12 @@ class IInt8EntropyCalibrator2;
 }
 
 namespace torch {
-namespace data {
-template<typename Batch>
-class Iterator;
+class Tensor;
+}
+
+namespace trtorch {
+namespace ptq {
+bool get_batch_impl(void* bindings[], const char* names[], int nbBindings, torch::Tensor& data);
 }
 }
 #endif //DOXYGEN_SHOULD_SKIP_THIS
@@ -45,7 +50,12 @@ class Int8Calibrator : Algorithm {
    * @param use_cache : bool - Whether to use the cache (if it exists)
    */
   Int8Calibrator(DataLoaderUniquePtr dataloader, const std::string& cache_file_path, bool use_cache)
-    : dataloader_(dataloader.get()), it_(dataloader_->end()), cache_file_path_(cache_file_path), use_cache_(use_cache) {}
+    : dataloader_(dataloader.get()), cache_file_path_(cache_file_path), use_cache_(use_cache) {
+    for (auto batch : *dataloader_) {
+      batched_data_.push_back(batch.data);
+    }
+    it_ = batched_data_.begin();
+  }
 
   /**
    * @brief Get the Batch Size for the next batch (always 1 due to issues with TRT and explicit batch)
@@ -70,26 +80,15 @@
    * @return false - There is not a new batch for the calibrator to consume
    */
   bool getBatch(void* bindings[], const char* names[], int nbBindings) override {
-    // HACK: doesnt seem like the first try in the initializer list works
-    if (!it_created_) {
-      it_ = dataloader_->begin();
-      it_created_ = true;
-    }
-
-    if (it_ == dataloader_->end()) {
+    if (it_ != batched_data_.end()) {
+      auto status = get_batch_impl(bindings, names, nbBindings, *it_);
+      it_ = ++it_;
+      return status;
+    } else {
+      // Reset the iterator in case the calibrator is going to be used again
+      it_ = batched_data_.begin();
       return false;
     }
-
-    auto batch = *it_;
-
-    for (int i = 0; i < nbBindings; i++) {
-      auto data = batch.data;
-      data = data.to(at::kCUDA).contiguous();
-      bindings[i] = data.data_ptr();
-    }
-
-    it_ = ++it_;
-    return true;
   }
 
   /**
@@ -151,8 +150,6 @@ private:
   /// Pointer to the dataloader
   DataLoader* dataloader_;
-  /// Iterator used to traverse the dataloader
-  torch::data::Iterator<torch::data::Example<>> it_;
   /// Path to cache file
   const std::string& cache_file_path_;
   /// Size of cache
@@ -161,10 +158,11 @@
   bool use_cache_;
   /// Cache data
   std::vector<char> cache_;
-  /// If the iterator has been created, DataLoaders can only have 1 live iterator,
-  /// due to some issues this cannot be created at construction, so it is set in the first
-  /// batch, controlled by this flag
-  bool it_created_ = false;
+  /// Batched Data
+  std::vector<torch::Tensor> batched_data_;
+  /// Iterator to move through dataset
+  std::vector<torch::Tensor>::iterator it_;
+
 };
 
 /**
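Note on the `ptq.h` rework above: the constructor now walks the dataloader once and stores every batch tensor, so `getBatch()` can rewind simply by resetting a `std::vector` iterator instead of juggling a live `torch::data::Iterator` (the source of the old `it_created_` hack). A minimal standalone sketch of that pattern, using a hypothetical helper name that is not part of the patch:

```cpp
// Sketch only: eagerly materialize the batches a torch dataloader yields, the same
// strategy the new Int8Calibrator constructor uses. `materialize_batches` is a
// hypothetical name for illustration.
#include "torch/torch.h"

#include <vector>

template <typename DataLoader>
std::vector<torch::Tensor> materialize_batches(DataLoader& dataloader) {
  std::vector<torch::Tensor> batched_data;
  for (auto batch : dataloader) {
    // Only the input tensors are kept; labels are not needed for INT8 calibration.
    batched_data.push_back(batch.data);
  }
  return batched_data;
}
```

Materializing every batch up front trades host memory for the ability to restart calibration, which is cheap for the 320-image CIFAR-10 subset used in `cpp/ptq`.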
diff --git a/cpp/api/src/ptq.cpp b/cpp/api/src/ptq.cpp
new file mode 100644
index 0000000000..ec2223a821
--- /dev/null
+++ b/cpp/api/src/ptq.cpp
@@ -0,0 +1,16 @@
+#include "torch/torch.h"
+#include "trtorch/ptq.h"
+
+namespace trtorch {
+namespace ptq {
+
+bool get_batch_impl(void* bindings[], const char* names[], int nbBindings, torch::Tensor& data) {
+  for (int i = 0; i < nbBindings; i++) {
+    data = data.to(at::kCUDA).contiguous();
+    bindings[i] = data.data_ptr();
+  }
+  return true;
+}
+
+} // namespace ptq
+} // namespace trtorch
\ No newline at end of file
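A note on `get_batch_impl` above: it moves the batch to the GPU, makes it contiguous, and hands the same `data_ptr()` to every binding; because `data` is a reference into the calibrator's `batched_data_` vector, the device copy stays alive until the next batch is requested, which is what TensorRT expects from `getBatch()`. A hedged sketch of the calling contract (the binding name and shape are made up, and a CUDA device is assumed to be present):

```cpp
// Sketch only: exercise get_batch_impl the way the calibrator's getBatch() does.
#include "torch/torch.h"
#include "trtorch/ptq.h"

#include <iostream>

int main() {
  torch::Tensor batch = torch::rand({32, 3, 32, 32});  // one CPU calibration batch
  const char* names[] = {"input_0"};                    // binding names normally come from the network
  void* bindings[1] = {nullptr};

  // Afterwards `batch` lives on the GPU and bindings[0] points at its storage.
  bool ok = trtorch::ptq::get_batch_impl(bindings, names, 1, batch);
  std::cout << std::boolalpha << ok << " device: " << batch.device() << std::endl;
  return 0;
}
```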
diff --git a/cpp/benchmark/README.md b/cpp/benchmark/README.md
index 73041c9ff3..1c45bc9fbe 100644
--- a/cpp/benchmark/README.md
+++ b/cpp/benchmark/README.md
@@ -1,6 +1,6 @@
 # Benchmarking
 
-This is a quick benchmarking application for TRTorch.  It lets you run supported TorchScript modules both in JIT and TRT and returns the average runtime and throughput.
+This is a quick benchmarking application for TRTorch. It lets you run supported TorchScript modules both in JIT and TRT and returns the average runtime and throughput.
 
 ## Compilation / Usage
 
@@ -20,7 +20,7 @@ bazel run //cpp/benchmark --cxxopt="-DNDEBUG" --cxxopt="-DJIT" --cxxopt="-DTRT"
 
 ### Options
 
-You can run a module with JIT or TRT via TRTorch in either FP32 or FP16.  These options are controlled by preprocessor directives.
+You can run a module with JIT or TRT via TRTorch in either FP32 or FP16. These options are controlled by preprocessor directives.
 
 - To enable JIT profiling, add the argument `--cxxopt="-DJIT"`
 
@@ -28,4 +28,6 @@ You can run a module with JIT or TRT via TRTorch in either FP32 or FP16.  These o
 
 - To enable FP16 execution, add the argument `--cxxopt="-DHALF"`
 
+- To also save the TRT engine, add the argument `--cxxopt="-DSAVE_ENGINE"`
+
 > It's suggested to also define `--cxxopt="-DNDEBUG"` to supress debug information
diff --git a/cpp/benchmark/main.cpp b/cpp/benchmark/main.cpp
index 05c84cf7ac..e73f1da4e8 100644
--- a/cpp/benchmark/main.cpp
+++ b/cpp/benchmark/main.cpp
@@ -105,15 +105,6 @@ int main(int argc, const char* argv[]) {
 
   mod.to(at::kCUDA);
 
-#ifdef HALF
-  mod.to(torch::kHalf);
-  for (auto layer : mod.named_modules()) {
-    if (layer.name.find(".bn") != std::string::npos) {
-      layer.value.to(torch::kFloat);
-    }
-  }
-#endif
-
   std::vector<std::vector<int64_t>> dims;
   for (int i = 2; i < argc; i++) {
     auto arg = std::string(argv[i]);
@@ -129,23 +120,42 @@ int main(int argc, const char* argv[]) {
 
   at::globalContext().setBenchmarkCuDNN(true);
 
-#ifdef JIT
-  auto jit_runtimes = benchmark_module(mod, dims[0]);
-  print_avg_std_dev("JIT", jit_runtimes, dims[0][0]);
-#endif
-
 #ifdef TRT
   auto extra_info = trtorch::ExtraInfo(dims);
-  extra_info.workspace_size = 1 << 24;
+  extra_info.workspace_size = 1 << 20;
 
 #ifdef HALF
-  extra_info.op_precision = at::kHalf;
+  extra_info.op_precision = torch::kF16;
 #endif
 
   auto trt_mod = trtorch::CompileGraph(mod, extra_info);
+
+#ifdef SAVE_ENGINE
+  std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
+  auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", extra_info);
+  std::ofstream out("/tmp/engine_converted_from_jit.trt");
+  out << engine;
+  out.close();
+#endif
+
   auto trt_runtimes = benchmark_module(trt_mod, dims[0]);
   print_avg_std_dev("JIT/TRT", trt_runtimes, dims[0][0]);
 #endif
+
+#ifdef HALF
+  mod.to(torch::kHalf);
+  for (auto layer : mod.named_modules()) {
+    if (layer.name.find(".bn") != std::string::npos) {
+      layer.value.to(torch::kFloat);
+    }
+  }
+#endif
+
+#ifdef JIT
+  auto jit_runtimes = benchmark_module(mod, dims[0]);
+  print_avg_std_dev("JIT", jit_runtimes, dims[0][0]);
+#endif
+
   std::cout << "ok\n";
 }
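The `SAVE_ENGINE` path added above writes the serialized engine string returned by `trtorch::ConvertGraphToTRTEngine` to `/tmp/engine_converted_from_jit.trt` (opening the stream with `std::ios::binary` would be slightly safer for an opaque blob). A hedged sketch of loading that file back with the plain TensorRT runtime; the logger and deserialize signatures shown match the TensorRT 7 era API and may differ on newer releases:

```cpp
// Sketch only: deserialize the engine written by the SAVE_ENGINE path.
#include "NvInfer.h"

#include <fstream>
#include <iostream>
#include <sstream>
#include <string>

class SimpleLogger : public nvinfer1::ILogger {
  void log(Severity severity, const char* msg) override {
    if (severity <= Severity::kWARNING) {
      std::cerr << msg << std::endl;
    }
  }
};

int main() {
  std::ifstream in("/tmp/engine_converted_from_jit.trt", std::ios::binary);
  std::stringstream buffer;
  buffer << in.rdbuf();
  std::string blob = buffer.str();

  SimpleLogger logger;
  auto runtime = nvinfer1::createInferRuntime(logger);
  auto engine = runtime->deserializeCudaEngine(blob.data(), blob.size(), nullptr);
  // ... create an IExecutionContext from `engine` and run inference ...
  return engine != nullptr ? 0 : 1;
}
```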
diff --git a/cpp/ptq/BUILD b/cpp/ptq/BUILD
index fd8261c08b..d190b0e7db 100644
--- a/cpp/ptq/BUILD
+++ b/cpp/ptq/BUILD
@@ -4,9 +4,9 @@ cc_binary(
     name = "ptq",
     srcs = [
         "main.cpp",
-        "timer.h"
     ],
     deps = [
+        "//cpp/ptq/benchmark",
        "//cpp/ptq/datasets:cifar10",
        "@libtorch//:libtorch",
        "@libtorch//:caffe2",
diff --git a/cpp/ptq/benchmark/BUILD b/cpp/ptq/benchmark/BUILD
new file mode 100644
index 0000000000..8647b240da
--- /dev/null
+++ b/cpp/ptq/benchmark/BUILD
@@ -0,0 +1,17 @@
+package(default_visibility = ["//visibility:public"])
+
+cc_library(
+    name = "benchmark",
+    hdrs = [
+        "benchmark.h"
+    ],
+    srcs = [
+        "benchmark.cpp",
+        "timer.h"
+    ],
+    deps = [
+        "@libtorch//:libtorch",
+        "@libtorch//:caffe2",
+        "//cpp/api:trtorch"
+    ],
+)
diff --git a/cpp/ptq/benchmark/benchmark.cpp b/cpp/ptq/benchmark/benchmark.cpp
new file mode 100644
index 0000000000..b88ec7583c
--- /dev/null
+++ b/cpp/ptq/benchmark/benchmark.cpp
@@ -0,0 +1,70 @@
+#include "torch/script.h"
+#include "torch/torch.h"
+#include "ATen/Context.h"
+#include "c10/cuda/CUDACachingAllocator.h"
+#include "trtorch/trtorch.h"
+#include "cuda_runtime_api.h"
+
+#include "timer.h"
+
+#define NUM_WARMUP_RUNS 20
+#define NUM_RUNS 100
+
+// Benchmarking code
+void print_avg_std_dev(std::string type, std::vector<float>& runtimes, uint64_t batch_size) {
+  float avg_runtime = std::accumulate(runtimes.begin(), runtimes.end(), 0.0) / runtimes.size();
+  float fps = (1000.f / avg_runtime) * batch_size;
+  std::cout << "[" << type << "]: batch_size: " << batch_size << "\n    Average latency: " << avg_runtime << " ms\n    Average FPS: " << fps << " fps" << std::endl;
+
+  std::vector<float> rt_diff(runtimes.size());
+  std::transform(runtimes.begin(), runtimes.end(), rt_diff.begin(), [avg_runtime](float x) { return x - avg_runtime; });
+  float rt_sq_sum = std::inner_product(rt_diff.begin(), rt_diff.end(), rt_diff.begin(), 0.0);
+  float rt_std_dev = std::sqrt(rt_sq_sum / runtimes.size());
+
+  std::vector<float> fps_diff(runtimes.size());
+  std::transform(runtimes.begin(), runtimes.end(), fps_diff.begin(), [fps, batch_size](float x) { return ((1000.f / x) * batch_size) - fps; });
+  float fps_sq_sum = std::inner_product(fps_diff.begin(), fps_diff.end(), fps_diff.begin(), 0.0);
+  float fps_std_dev = std::sqrt(fps_sq_sum / runtimes.size());
+  std::cout << "    Latency Standard Deviation: " << rt_std_dev << "\n    FPS Standard Deviation: " << fps_std_dev << "\n(excluding initial warmup runs)" << std::endl;
+}
+
+std::vector<float> benchmark_module(torch::jit::script::Module& mod, std::vector<int64_t> shape) {
+  auto execution_timer = timers::PreciseCPUTimer();
+  std::vector<float> execution_runtimes;
+
+  for (uint64_t i = 0; i < NUM_WARMUP_RUNS; i++) {
+    std::vector<torch::jit::IValue> inputs_ivalues;
+    auto in = at::rand(shape, {at::kCUDA});
+#ifdef HALF
+    in = in.to(torch::kHalf);
+#endif
+    inputs_ivalues.push_back(in.clone());
+
+    cudaDeviceSynchronize();
+    mod.forward(inputs_ivalues);
+    cudaDeviceSynchronize();
+
+  }
+
+  for (uint64_t i = 0; i < NUM_RUNS; i++) {
+    std::vector<torch::jit::IValue> inputs_ivalues;
+    auto in = at::rand(shape, {at::kCUDA});
+#ifdef HALF
+    in = in.to(torch::kHalf);
+#endif
+    inputs_ivalues.push_back(in.clone());
+    cudaDeviceSynchronize();
+
+    execution_timer.start();
+    mod.forward(inputs_ivalues);
+    cudaDeviceSynchronize();
+    execution_timer.stop();
+
+    auto time = execution_timer.milliseconds();
+    execution_timer.reset();
+    execution_runtimes.push_back(time);
+
+    c10::cuda::CUDACachingAllocator::emptyCache();
+  }
+  return execution_runtimes;
+}
diff --git a/cpp/ptq/benchmark/benchmark.h b/cpp/ptq/benchmark/benchmark.h
new file mode 100644
index 0000000000..3c11833ab3
--- /dev/null
+++ b/cpp/ptq/benchmark/benchmark.h
@@ -0,0 +1,4 @@
+#pragma once
+
+void print_avg_std_dev(std::string type, std::vector<float>& runtimes, uint64_t batch_size);
+std::vector<float> benchmark_module(torch::jit::script::Module& mod, std::vector<int64_t> shape);
diff --git a/cpp/ptq/timer.h b/cpp/ptq/benchmark/timer.h
similarity index 100%
rename from cpp/ptq/timer.h
rename to cpp/ptq/benchmark/timer.h
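With the timing code factored into `//cpp/ptq/benchmark` above, other targets can reuse it. Note that `benchmark.h` does not include its own dependencies, so `torch/script.h` (and therefore `<string>`/`<vector>`) must be included before it, as `cpp/ptq/main.cpp` does. A minimal usage sketch, with a made-up module path:

```cpp
// Sketch only: drive the new benchmark library from another target.
#include "torch/script.h"
#include "benchmark/benchmark.h"  // include path as used from cpp/ptq

#include <vector>

int main() {
  auto mod = torch::jit::load("/tmp/my_module.ts");  // hypothetical TorchScript module
  mod.to(at::kCUDA);
  mod.eval();

  std::vector<int64_t> shape = {32, 3, 32, 32};
  auto runtimes = benchmark_module(mod, shape);  // NUM_WARMUP_RUNS warmups, then NUM_RUNS timed runs
  print_avg_std_dev("JIT", runtimes, shape[0]);  // batch size is the leading dimension
  return 0;
}
```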
diff --git a/cpp/ptq/main.cpp b/cpp/ptq/main.cpp
index 9e806b5e59..ed40398826 100644
--- a/cpp/ptq/main.cpp
+++ b/cpp/ptq/main.cpp
@@ -5,31 +5,34 @@
 #include "NvInfer.h"
 
 #include "datasets/cifar10.h"
-#include "timer.h"
+#include "benchmark/benchmark.h"
 
 #include
 #include
 #include
 #include
 
-int main(int argc, const char* argv[]) {
-  if (argc < 3) {
-    std::cerr << "usage: ptq <path-to-module> <path-to-cifar10>\n";
-    return -1;
-  }
+namespace F = torch::nn::functional;
 
-  torch::jit::Module mod;
-  try {
-    /// Deserialize the ScriptModule from a file using torch::jit::load().
-    mod = torch::jit::load(argv[1]);
-  }
-  catch (const c10::Error& e) {
-    std::cerr << "error loading the model\n";
-    return -1;
+// Actual PTQ application code
+struct Resize : public torch::data::transforms::TensorTransform<torch::Tensor> {
+  Resize(std::vector<int64_t> new_size)
+    : new_size_(new_size) {}
+
+  torch::Tensor operator()(torch::Tensor input) {
+    input = input.unsqueeze(0);
+    auto upsampled = F::interpolate(input, F::InterpolateFuncOptions()
+                                               .size(new_size_)
+                                               .align_corners(false)
+                                               .mode(torch::kBilinear));
+    return upsampled.squeeze(0);
   }
 
-  /// Create the calibration dataset
-  const std::string data_dir = std::string(argv[2]);
+  std::vector<int64_t> new_size_;
+};
+
+torch::jit::Module compile_int8_model(const std::string& data_dir, torch::jit::Module& mod) {
   auto calibration_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
                                  .use_subset(320)
                                  .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
@@ -58,14 +61,49 @@
 
   mod.eval();
 
+#ifdef SAVE_ENGINE
+  std::cout << "Compiling graph to save as TRT engine (/tmp/engine_converted_from_jit.trt)" << std::endl;
+  auto engine = trtorch::ConvertGraphToTRTEngine(mod, "forward", extra_info);
+  std::ofstream out("/tmp/engine_converted_from_jit.trt");
+  out << engine;
+  out.close();
+#endif
+
+  std::cout << "Compiling and quantizing module" << std::endl;
+  auto trt_mod = trtorch::CompileGraph(mod, extra_info);
+  return std::move(trt_mod);
+}
+
+int main(int argc, const char* argv[]) {
+  at::globalContext().setBenchmarkCuDNN(true);
+
+  if (argc < 3) {
+    std::cerr << "usage: ptq <path-to-module> <path-to-cifar10>\n";
+    return -1;
+  }
+
+  torch::jit::Module mod;
+  try {
+    /// Deserialize the ScriptModule from a file using torch::jit::load().
+    mod = torch::jit::load(argv[1]);
+  }
+  catch (const c10::Error& e) {
+    std::cerr << "error loading the model\n";
+    return -1;
+  }
+
+  /// Create the calibration dataset
+  const std::string data_dir = std::string(argv[2]);
+  auto trt_mod = compile_int8_model(data_dir, mod);
+
   /// Dataloader moved into calibrator so need another for inference
   auto eval_dataset = datasets::CIFAR10(data_dir, datasets::CIFAR10::Mode::kTest)
                           .map(torch::data::transforms::Normalize<>({0.4914, 0.4822, 0.4465},
-                                                                    {0.2023, 0.1994, 0.2010}))
+                                                                     {0.2023, 0.1994, 0.2010}))
                           .map(torch::data::transforms::Stack<>());
-  auto eval_dataloader = torch::data::make_data_loader(std::move(eval_dataset), torch::data::DataLoaderOptions()
-                                                                                    .batch_size(32)
-                                                                                    .workers(2));
+  auto eval_dataloader = torch::data::make_data_loader(std::move(eval_dataset),
+                                                       torch::data::DataLoaderOptions().batch_size(32)
+                                                           .workers(2));
 
   /// Check the FP32 accuracy in JIT
   float correct = 0.0, total = 0.0;
@@ -81,10 +119,6 @@
   }
   std::cout << "Accuracy of JIT model on test set: " << 100 * (correct / total) << "%" << std::endl;
 
-  /// Compile Graph
-  std::cout << "Compiling and quantizing module" << std::endl;
-  auto trt_mod = trtorch::CompileGraph(mod, extra_info);
-
   /// Check the INT8 accuracy in TRT
   correct = 0.0;
   total = 0.0;
@@ -116,19 +150,13 @@
   std::cout << "Accuracy of quantized model on test set: " << 100 * (correct / total) << "%" << std::endl;
 
   /// Time execution in JIT-FP32 and TRT-INT8
-  auto execution_timer = timers::PreciseCPUTimer();
-  auto images = (*(*eval_dataloader).begin()).data.to(torch::kCUDA);
+  std::vector<std::vector<int64_t>> dims = {{32, 3, 32, 32}};
 
-  execution_timer.start();
-  mod.forward({images});
-  execution_timer.stop();
-  std::cout << "Latency of JIT model FP32 (Batch Size 32): " << execution_timer.milliseconds() << "ms" << std::endl;
+  auto jit_runtimes = benchmark_module(mod, dims[0]);
+  print_avg_std_dev("JIT model FP32", jit_runtimes, dims[0][0]);
 
-  execution_timer.reset();
+  auto trt_runtimes = benchmark_module(trt_mod, dims[0]);
+  print_avg_std_dev("TRT quantized model", trt_runtimes, dims[0][0]);
 
-  execution_timer.start();
-  trt_mod.forward({images});
-  execution_timer.stop();
-  std::cout << "Latency of quantized model (Batch Size 32): " << execution_timer.milliseconds() << "ms" << std::endl;
 }
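The `compile_int8_model` hunks above use an `extra_info` object that is built in unchanged context lines, so its setup is not visible in this patch. The sketch below shows roughly how the calibrator and `ExtraInfo` are expected to be wired together; `make_int8_calibrator`, the `ptq_calibrator` field, the dtype constant, and the cache path are assumptions taken from the surrounding API and should be checked against `trtorch/ptq.h` and `trtorch/trtorch.h`:

```cpp
// Sketch only: plausible calibrator + ExtraInfo wiring for compile_int8_model.
#include "torch/torch.h"
#include "trtorch/trtorch.h"
#include "trtorch/ptq.h"
#include "NvInfer.h"

#include <utility>

template <typename DataLoader>
torch::jit::Module compile_int8(torch::jit::Module& mod, DataLoader calibration_dataloader) {
  // Assumed helper: wraps the dataloader in an Int8Calibrator backed by a cache file.
  auto calibrator = trtorch::ptq::make_int8_calibrator<nvinfer1::IInt8EntropyCalibrator2>(
      std::move(calibration_dataloader), "/tmp/vgg16_TRT_ptq_calibration.cache", /*use_cache=*/true);

  trtorch::ExtraInfo extra_info({{32, 3, 32, 32}});
  extra_info.op_precision = torch::kInt8;   // build the engine in INT8 (exact constant may differ)
  extra_info.ptq_calibrator = calibrator;   // assumed field name for attaching the calibrator
  extra_info.workspace_size = 1 << 28;

  // The calibrator must stay alive for the duration of compilation.
  return trtorch::CompileGraph(mod, extra_info);
}
```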