Skip to content

Commit 339919d

Browse files
committed
chore: Fix lowering, comment CSE pass
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 5df502a commit 339919d

File tree

5 files changed

+38
-15
lines changed

(duplicate of the file-change summary above removed)

core/conversion/conversionctx/ConversionCtx.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,10 @@ ConversionCtx::ConversionCtx(BuilderSettings build_settings)
7070
cfg->setFlag(nvinfer1::BuilderFlag::kFP16);
7171
}
7272
input_type = nvinfer1::DataType::kFLOAT;
73-
// TRTORCH_CHECK(
74-
// settings.calibrator != nullptr,
75-
// "Requested inference in INT8 but no calibrator provided, set the ptq_calibrator field in the CompileSpec
76-
// struct with your calibrator");
77-
// cfg->setInt8Calibrator(settings.calibrator);
73+
// Networks trained with Quantization aware training approach don't need a calibrator as they have Q/DQ nodes.
74+
if (!settings.calibrator){
75+
LOG_WARNING("Int8 precision has been enabled but no calibrator provided. This assumes the network has Q/DQ nodes obtained from Quantization aware training. For more details, refer to https://docs.nvidia.com/deeplearning/tensorrt/developer-guide/index.html#work-with-qat-networks");
76+
}
7877
break;
7978
case nvinfer1::DataType::kFLOAT:
8079
default:

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ cc_library(
4747
"impl/matrix_multiply.cpp",
4848
"impl/normalize.cpp",
4949
"impl/pooling.cpp",
50+
"impl/quantization.cpp",
5051
"impl/reduce.cpp",
5152
"impl/replication_pad.cpp",
5253
"impl/select.cpp",

core/conversion/converters/impl/linear.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,27 @@ auto linear_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().patt
6464
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
6565
return true;
6666
}
67+
68+
auto w_tensor = args[1].IValue()->toTensor();
69+
Weights w = Weights(ctx, w_tensor);
70+
71+
nvinfer1::ILayer* new_layer;
72+
if (!args[2].IValue()->isNone()) {
73+
Weights b(ctx, args[2].IValue()->toTensor());
74+
new_layer = ctx->net->addFullyConnected(*in, w.num_output_maps, w.data, b.data);
75+
} else {
76+
LOG_DEBUG("There is no bias for the linear layer");
77+
new_layer = ctx->net->addFullyConnected(*in, w.num_output_maps, w.data, Weights().data);
78+
}
79+
80+
TRTORCH_CHECK(new_layer, "Unable to create linear layer from node: " << *n);
81+
82+
new_layer->setName(util::node_info(n).c_str());
83+
auto out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], new_layer->getOutput(0));
84+
85+
LOG_DEBUG("Output tensor shape: " << out_tensor->getDimensions());
86+
87+
return true;
6788
}});
6889
} // namespace
6990
} // namespace impl

core/lowering/lowering.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
#include "core/lowering/lowering.h"
2-
#include <torch/csrc/jit/passes/inliner.h>
3-
#include "core/lowering/passes/passes.h"
4-
#include "core/util/prelude.h"
51
#include "torch/csrc/jit/passes/common_subexpression_elimination.h"
62
#include "torch/csrc/jit/passes/create_functional_graphs.h"
73
#include "torch/csrc/jit/passes/dead_code_elimination.h"
@@ -14,6 +10,10 @@
1410
#include "torch/csrc/jit/passes/peephole.h"
1511
#include "torch/csrc/jit/passes/remove_mutation.h"
1612

13+
#include "core/lowering/lowering.h"
14+
#include "core/lowering/passes/passes.h"
15+
#include "core/util/prelude.h"
16+
1717
namespace trtorch {
1818
namespace core {
1919
namespace lowering {
@@ -42,9 +42,10 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
4242
passes::Conv3DToConvolution(g);
4343
passes::FuseAddMMBranches(g);
4444
passes::RemoveBNDimCheck(g);
45-
torch::jit::EliminateCommonSubexpression(g);
45+
LOG_INFO("====PRE CSE =====" << *g);
46+
// torch::jit::EliminateCommonSubexpression(g);
47+
LOG_INFO("====POST CSE =====" << *g);
4648
// torch::jit::UnrollLoops(g);
47-
torch::jit::EliminateCommonSubexpression(g);
4849
passes::UnpackAddMM(g);
4950
// passes::UnpackBatchNorm(g);
5051
passes::UnpackLogSoftmax(g);
@@ -65,18 +66,17 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
6566
std::string method_name) {
6667
auto lowered_mod = mod; // LowerModule(mod);
6768
auto g = lowered_mod.get_method(method_name).graph();
68-
Inline(*g);
69-
LOG_INFO("========INLINING : " << *g);
69+
LOG_GRAPH(*g);
7070

7171
// Go through TRTorch Lowering to reformat graph to be conversion friendly
7272
// and also segment for accelerators and executors (TRT-DLA, TRT-GPU, PYT)
7373
LOG_GRAPH("TRTorch Graph Lowering");
74-
lowering::LowerGraph(g);
74+
// lowering::LowerGraph(g);
7575

7676
LOG_GRAPH("LibTorch Lowering");
7777
auto graph_and_ivalues = torch::jit::LowerGraph(*g, lowered_mod._ivalue());
78+
lowering::LowerGraph(graph_and_ivalues.first);
7879
// Is this necessary?
79-
8080
lowering::LowerBlock(g->block());
8181

8282
return graph_and_ivalues;

core/util/jit_util.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ inline std::string node_info(const torch::jit::Node* n) {
1313
std::stringstream ss;
1414
ss << *n;
1515
std::string node_info = ss.str();
16+
// Nodes in torchscript graph have file name and line numbers commented for every node. Remove that when returning a node name for easier readability.
17+
node_info = node_info.substr(0, node_info.find("#", 0));
1618
node_info.erase(std::remove(node_info.begin(), node_info.end(), '\n'), node_info.end());
1719
return node_info;
1820
}

0 commit comments

Comments
 (0)