
Commit 54f08f9

Merge remote-tracking branch 'origin/master' into trt_8

2 parents: 169c5bc + bdaacf1

21 files changed: +357 -59 lines

core/compiler.cpp

Lines changed: 4 additions & 4 deletions

@@ -182,8 +182,8 @@ torch::jit::script::Module CompileGraphWithFallback(const torch::jit::script::Mo
   torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
   std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
   for (const torch::jit::script::Method& method : mod.get_methods()) {
-    // Don't convert hidden methods
-    if (method.name().rfind("_", 0)) {
+    // Compile only the forward method; it contains the entire graph.
+    if (method.name().compare("forward") == 0) {
       auto new_g = std::make_shared<torch::jit::Graph>();
       auto graph_and_parameters = lowering::Lower(mod, method.name());

@@ -256,8 +256,8 @@ torch::jit::script::Module CompileGraph(const torch::jit::script::Module& mod, C
   torch::jit::script::Module new_mod(mod._ivalue()->name() + "_trt");
   std::vector<std::shared_ptr<torch::jit::Graph>> graphs;
   for (const torch::jit::script::Method& method : mod.get_methods()) {
-    // Don't convert hidden methods
-    if (method.name().rfind("_", 0)) {
+    // Compile only the forward method; it contains the entire graph.
+    if (method.name().compare("forward") == 0) {
       auto engine = ConvertGraphToTRTEngine(mod, method.name(), cfg);
       auto new_g = std::make_shared<torch::jit::Graph>();
       AddEngineToGraph(new_mod, new_g, engine);
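
For context on the guard that changed here: std::string::rfind(prefix, 0) is the usual starts-with idiom, so the old condition compiled every method whose name did not begin with "_", while the new one compiles only forward. A minimal sketch of the two predicates (the method names are illustrative):

#include <cassert>
#include <string>

int main() {
  // rfind(s, 0) returns 0 when the string starts with s, and npos otherwise.
  std::string hidden = "_helper";
  std::string visible = "forward";
  assert(hidden.rfind("_", 0) == 0);                  // falsy  -> was skipped
  assert(visible.rfind("_", 0) == std::string::npos); // truthy -> was compiled
  // The new guard is narrower: only "forward" itself is compiled.
  assert(visible.compare("forward") == 0);
  return 0;
}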

core/conversion/converters/impl/layer_norm.cpp

Lines changed: 24 additions & 5 deletions

@@ -117,12 +117,31 @@ auto layer_norm_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().
     }

     auto power = Weights(ctx, at::ones(expand_size));
-    auto scale_nd = ctx->net->addScaleNd(
-        *div_out, nvinfer1::ScaleMode::kELEMENTWISE, beta_weights.data, gamma_weights.data, power.data, 1);
-    scale_nd->setName((util::node_info(n) + "_scale_nd").c_str());
-    auto scale_nd_out = scale_nd->getOutput(0);

-    ctx->AssociateValueAndTensor(n->outputs()[0], scale_nd_out);
+    auto gamma_tensor = ctx->net->addConstant(gamma_weights.shape, gamma_weights.data)->getOutput(0);
+    auto scale_l = add_elementwise(
+        ctx, nvinfer1::ElementWiseOperation::kPROD, div_out, gamma_tensor, (util::node_info(n) + "_scale").c_str());
+
+    auto beta_tensor = ctx->net->addConstant(beta_weights.shape, beta_weights.data)->getOutput(0);
+    auto shift_l = add_elementwise(
+        ctx,
+        nvinfer1::ElementWiseOperation::kSUM,
+        scale_l->getOutput(0),
+        beta_tensor,
+        (util::node_info(n) + "_shift").c_str());
+
+    auto power_tensor = ctx->net->addConstant(power.shape, power.data)->getOutput(0);
+    auto power_l = add_elementwise(
+        ctx,
+        nvinfer1::ElementWiseOperation::kPOW,
+        shift_l->getOutput(0),
+        power_tensor,
+        (util::node_info(n) + "_power").c_str());
+
+    power_l->setName((util::node_info(n) + "_scale_nd").c_str());
+    auto power_l_out = power_l->getOutput(0);
+
+    ctx->AssociateValueAndTensor(n->outputs()[0], power_l_out);
     return true;
   }});
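
For context: TensorRT's IScaleLayer in kELEMENTWISE mode computes (input * scale + shift) ^ power, so the removed addScaleNd call produced (div_out * gamma + beta) ^ 1. The three explicit elementwise layers (kPROD, kSUM, kPOW against addConstant tensors) reproduce the same arithmetic. A minimal plain-C++ sketch of the equivalence, with illustrative values and no TensorRT dependency:

#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

int main() {
  std::vector<double> in{0.5, -1.0, 2.0};   // stands in for div_out
  std::vector<double> gamma{1.5, 0.5, 2.0}; // scale weights
  std::vector<double> beta{0.1, 0.2, 0.3};  // shift weights
  const double power = 1.0;                 // at::ones(expand_size)
  for (std::size_t i = 0; i < in.size(); ++i) {
    // What addScaleNd(kELEMENTWISE, shift=beta, scale=gamma, power=ones) computed:
    double scale_nd = std::pow(in[i] * gamma[i] + beta[i], power);
    // What the replacement layer chain computes:
    double prod = in[i] * gamma[i];     // kPROD with the gamma constant
    double sum = prod + beta[i];        // kSUM with the beta constant
    double powd = std::pow(sum, power); // kPOW with the ones constant
    assert(std::abs(scale_nd - powd) < 1e-12);
  }
  return 0;
}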

core/conversion/converters/impl/matrix_multiply.cpp

Lines changed: 9 additions & 4 deletions

@@ -1,3 +1,4 @@
+#include "core/conversion/converters/converter_util.h"
 #include "core/conversion/converters/converters.h"
 #include "core/util/prelude.h"

@@ -13,10 +14,14 @@ auto mm_registrations TRTORCH_UNUSED =
         .pattern({"aten::matmul(Tensor self, Tensor other) -> (Tensor)",
                   [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                     auto self = args[0].ITensorOrFreeze(ctx);
-                    LOG_DEBUG("self tensor shape: " << self->getDimensions());
-
                     auto other = args[1].ITensorOrFreeze(ctx);
-                    LOG_DEBUG("other tensor shape: " << other->getDimensions());
+                    // Ensure self and other have the same nbDims by expanding the
+                    // leading dimensions (from axis 0) where necessary.
+                    if (self->getDimensions().nbDims < other->getDimensions().nbDims) {
+                      self = addPadding(ctx, n, self, other->getDimensions().nbDims, false, false);
+                    } else {
+                      other = addPadding(ctx, n, other, self->getDimensions().nbDims, false, false);
+                    }

                     auto mm_layer = ctx->net->addMatrixMultiply(
                         *self, nvinfer1::MatrixOperation::kNONE, *other, nvinfer1::MatrixOperation::kNONE);

@@ -73,4 +78,4 @@ auto mm_registrations TRTORCH_UNUSED =
 } // namespace converters
 } // namespace conversion
 } // namespace core
-} // namespace trtorch
+} // namespace trtorch
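
aten::matmul broadcasts the batch dimensions of its operands, which requires both operands to have the same rank; the converter therefore pads the lower-rank tensor with leading size-1 dimensions (via the TRTorch addPadding util) before building the matrix-multiply layer. A plain-C++ sketch of just the shape logic; prependOnes below is an illustrative stand-in, not the real util:

#include <cassert>
#include <cstddef>
#include <vector>

// Prepend 1s until dims has target_rank entries, mirroring how the
// lower-rank matmul operand is expanded from axis 0.
std::vector<int> prependOnes(std::vector<int> dims, std::size_t target_rank) {
  while (dims.size() < target_rank) {
    dims.insert(dims.begin(), 1);
  }
  return dims;
}

int main() {
  std::vector<int> self{3, 4};     // rank 2
  std::vector<int> other{8, 4, 5}; // rank 3
  auto padded = prependOnes(self, other.size());
  assert((padded == std::vector<int>{1, 3, 4})); // now broadcast-compatible
  return 0;
}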

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions

@@ -25,6 +25,7 @@ void LowerBlock(torch::jit::Block* b) {
 }

 void LowerGraph(std::shared_ptr<torch::jit::Graph>& g) {
+  passes::UnpackHardSwish(g);
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
   torch::jit::RemoveTensorMutation(g);

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions

@@ -24,6 +24,7 @@ cc_library(
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",
         "unpack_log_softmax.cpp",
+        "unpack_hardswish.cpp"
     ],
     hdrs = [
         "passes.h",

core/lowering/passes/linear_to_addmm.cpp

Lines changed: 41 additions & 16 deletions

@@ -1,23 +1,55 @@
-#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include <torch/csrc/jit/runtime/operator.h>
+#include "torch/csrc/jit/ir/alias_analysis.h"
+#include "torch/csrc/jit/jit_log.h"
+#include "torch/csrc/jit/passes/constant_propagation.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/guard_elimination.h"
+#include "torch/csrc/jit/passes/peephole.h"
+#include "torch/csrc/jit/runtime/graph_executor.h"

 #include "core/util/prelude.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"

 namespace trtorch {
 namespace core {
 namespace lowering {
 namespace passes {

+void replaceLinearWithBiasNonePattern(std::shared_ptr<torch::jit::Graph> graph) {
+  // Define the decomposition function for aten::linear for the case where the bias (mat2) is None.
+  static torch::jit::CompilationUnit decompose_funcs(R"SCRIPT(
+    def linear(self: Tensor, mat1: Tensor, mat2: Tensor):
+        return torch.matmul(self, mat1.t())
+  )SCRIPT");
+
+  // Iterate through the nodes and search for aten::linear nodes whose bias is not a Tensor
+  // (this includes the bias=None case).
+  auto block = graph->block();
+  for (auto it = block->nodes().begin(); it != block->nodes().end(); it++) {
+    auto n = *it;
+    if (n->kind().toQualString() == std::string("aten::linear")) {
+      auto input_values = n->inputs();
+      // input_values[2] is the bias. If it is not a Tensor, replace the node with the decomposed linear graph.
+      if (input_values[2]->type()->isSubtypeOf(c10::TensorType::get())) {
+        continue;
+      } else {
+        torch::jit::WithInsertPoint guard(*it);
+        std::shared_ptr<torch::jit::Graph> d_graph = decompose_funcs.get_function("linear").graph();
+        torch::jit::Value* new_output = insertGraph(*it->owningGraph(), *d_graph, it->inputs()).at(0);
+        new_output->setType(it->output()->type());
+        it->output()->replaceAllUsesWith(new_output);
+        it.destroyCurrent();
+      }
+    }
+  }
+}
+
 void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
   // TensorRT implicitly adds a flatten layer in front of FC layers if necessary
   std::string flatten_linear_pattern = R"IR(
     graph(%input, %weight, %bias):
       %res = aten::linear(%input, %weight, %bias)
       return (%res))IR";
-  std::string flatten_linear_bias_none_pattern = R"IR(
-    graph(%input, %weight):
-      %bias: Tensor? = prim::Constant()
-      %res = aten::linear(%input, %weight, %bias)
-      return (%res))IR";

   std::string fused_linear = R"IR(
     graph(%input, %weight_t, %bias):

@@ -27,20 +59,13 @@ void LinearToAddMM(std::shared_ptr<torch::jit::Graph>& graph) {
       %b_f: Tensor = trt::const(%bias)
       %out: Tensor = aten::add(%b_f, %mm, %1)
       return (%out))IR";
-  std::string fused_linear_bias_none = R"IR(
-    graph(%input, %weight_t):
-      %weight = aten::t(%weight_t)
-      %mm: Tensor = aten::matmul(%input, %weight)
-      return (%mm))IR";
+
+  // First find and replace aten::linear nodes whose bias is not a Tensor.
+  replaceLinearWithBiasNonePattern(graph);

   torch::jit::SubgraphRewriter flatten_linear_to_linear;
   flatten_linear_to_linear.RegisterRewritePattern(flatten_linear_pattern, fused_linear);
   flatten_linear_to_linear.runOnGraph(graph);
-
-  torch::jit::SubgraphRewriter flatten_linear_bias_none_to_linear;
-  flatten_linear_bias_none_to_linear.RegisterRewritePattern(flatten_linear_bias_none_pattern, fused_linear_bias_none);
-  flatten_linear_bias_none_to_linear.runOnGraph(graph);
-  LOG_GRAPH("Post linear to addmm: " << *graph);
 }

 } // namespace passes
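
The bias=None case is now handled by an explicit walk over the graph nodes instead of a SubgraphRewriter pattern, decomposing aten::linear(x, W, None) into torch.matmul(x, W.t()). A small plain-C++ numeric check of that identity for a single row x and weight matrix W (values are illustrative):

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
  std::vector<double> x{1.0, 2.0, 3.0};                     // 1 x 3 input
  std::vector<std::vector<double>> W{{1, 0, 1}, {2, 1, 0}}; // 2 x 3 weight
  std::vector<double> out(W.size(), 0.0);                   // 1 x 2 result
  for (std::size_t nrow = 0; nrow < W.size(); ++nrow) {
    for (std::size_t k = 0; k < x.size(); ++k) {
      out[nrow] += x[k] * W[nrow][k]; // x @ W^T, with no bias term
    }
  }
  assert(out[0] == 4.0 && out[1] == 4.0);
  return 0;
}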

core/lowering/passes/passes.h

Lines changed: 1 addition & 0 deletions

@@ -21,6 +21,7 @@ void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);
 void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
 void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);

 } // namespace passes
 } // namespace lowering
core/lowering/passes/unpack_hardswish.cpp

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/util/prelude.h"
+
+namespace trtorch {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string hardswish_pattern = R"IR(
+    graph(%input):
+      %result = aten::hardswish(%input)
+      return (%result))IR";
+
+  std::string hardswish_pattern_inplace = R"IR(
+    graph(%input):
+      %result = aten::hardswish_(%input)
+      return (%result))IR";
+
+  std::string new_pattern = R"IR(
+    graph(%input):
+      %1 : Scalar = prim::Constant[value=3.]()
+      %2 : Scalar = prim::Constant[value=1.]()
+      %3 = aten::add(%input, %1, %2)
+      %4 : Scalar = prim::Constant[value=0.]()
+      %5 : Scalar = prim::Constant[value=6.]()
+      %6 = aten::hardtanh(%3, %4, %5)
+      %7 = aten::div(%6, %5)
+      %8 = aten::mul(%input, %7)
+      return (%8))IR";
+
+  torch::jit::SubgraphRewriter rewriter;
+  rewriter.RegisterRewritePattern(hardswish_pattern, new_pattern);
+  rewriter.RegisterRewritePattern(hardswish_pattern_inplace, new_pattern);
+  rewriter.runOnGraph(graph);
+
+  LOG_GRAPH("Post unpack hardswish: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace trtorch
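
The rewrite relies on the identity hardswish(x) = x * hardtanh(x + 3, 0, 6) / 6, matching the scalar constants in the IR above. A minimal plain-C++ spot check against the piecewise definition (helper names are illustrative):

#include <algorithm>
#include <cassert>
#include <cmath>

double hardtanh(double v, double lo, double hi) {
  return std::min(std::max(v, lo), hi);
}

double unpacked_hardswish(double x) {
  double t = x + 1.0 * 3.0;               // aten::add(%input, %1=3., %2=1.)
  double clamped = hardtanh(t, 0.0, 6.0); // aten::hardtanh(%3, %4=0., %5=6.)
  return x * (clamped / 6.0);             // aten::div(%6, %5); aten::mul(%input, %7)
}

int main() {
  assert(unpacked_hardswish(-4.0) == 0.0); // x <= -3 -> 0
  assert(unpacked_hardswish(5.0) == 5.0);  // x >= 3  -> x
  assert(std::abs(unpacked_hardswish(1.0) - 4.0 / 6.0) < 1e-12); // x*(x+3)/6
  return 0;
}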

core/partitioning/partitioning.cpp

Lines changed: 54 additions & 23 deletions

@@ -10,9 +10,9 @@ namespace core {
 namespace partitioning {

 struct usage_info {
-  int produce_id = -1;
-  std::vector<int> torch_use_id;
-  std::vector<int> tensorrt_use_id;
+  size_t produce_id; // id of the segmented block that produces this torch::jit::Value
+  std::vector<size_t> torch_use_id; // ids of the PyTorch segmented blocks that use this value
+  std::vector<size_t> tensorrt_use_id; // ids of the TensorRT segmented blocks that use this value
 };

 inline bool isTensorOrTensorList(torch::jit::Value* val) {

@@ -70,44 +70,54 @@ std::vector<torch::jit::Node*> getDependencyNodes(std::vector<torch::jit::Value*
   return stk;
 }

-std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_block) {
+std::vector<SegmentedBlock> segmentBlocksWithNonTensorInputs(SegmentedBlock& seg_block) {
   // reconstruct the segmented_block if this block requires nonTensor input
   std::vector<torch::jit::Value*> nontensor_inputs;
+  // Gather all non-tensor inputs for this seg_block
   for (auto input : seg_block.raw_inputs()) {
     if (!isTensorOrTensorList(input)) {
       nontensor_inputs.push_back(input);
     }
   }
-  std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(nontensor_inputs);

+  std::vector<torch::jit::Node*> dependency_nodes = getDependencyNodes(nontensor_inputs);
   std::vector<SegmentedBlock> new_seg_blocks;
-  // if current block is kTorch or current block is TensorRT and all dependent nodes are also supported, construct only
-  // one new block
+  // If the current block is kTorch, or it is kTensorRT and all dependency nodes are also supported, merge the
+  // dependency nodes at the beginning of the current segmented_block and return this merged segmented_block.
   if (seg_block.target() == SegmentedBlock::kTorch || isAllNodesSupported(dependency_nodes)) {
     dependency_nodes.insert(dependency_nodes.end(), seg_block.raw_nodes().begin(), seg_block.raw_nodes().end());
     new_seg_blocks.emplace_back(seg_block.target(), dependency_nodes);
   } else {
     // if the current block is kTensorRT but the dependency nodes contain unsupported nodes, we have to segment again
     std::unordered_set<torch::jit::Value*> nontensor_inputs_set(nontensor_inputs.begin(), nontensor_inputs.end());
-    std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes;
+    std::vector<torch::jit::Node*> tensorrt_nodes, pytorch_nodes(dependency_nodes.begin(), dependency_nodes.end());
+
     bool prev_non_tensor_outputs = false;
     for (auto n : seg_block.raw_nodes()) {
-      // it's a kTorch block if it uses the nonTensor input and the nonTensor input is produced in kTorch block
+      // Check whether the node has non-tensor inputs or consumes the non-tensor outputs of the previous node.
+      // In those cases the node goes into a new PyTorch SegmentedBlock; otherwise it joins a new TensorRT
+      // SegmentedBlock.
      if (containTargetInputs(n, nontensor_inputs_set) || prev_non_tensor_outputs) {
+        // If tensorrt_nodes is not empty, the preceding nodes were all TensorRT nodes. Construct a
+        // TensorRT segmented_block and clear tensorrt_nodes for the next TensorRT segment.
        if (!tensorrt_nodes.empty()) {
          new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
          tensorrt_nodes.clear();
        }
        pytorch_nodes.push_back(n);
        prev_non_tensor_outputs = containNonTensorOutputs(n);
      } else {
+        // If pytorch_nodes is not empty, the preceding nodes were all PyTorch nodes. Construct a
+        // PyTorch segmented_block and clear pytorch_nodes for the next PyTorch segment.
        if (!pytorch_nodes.empty()) {
          new_seg_blocks.emplace_back(SegmentedBlock::kTorch, pytorch_nodes);
          pytorch_nodes.clear();
        }
        tensorrt_nodes.push_back(n);
      }
    }
+
+    // Form the last segmented_block from whatever is left over in tensorrt_nodes or pytorch_nodes.
    if (!tensorrt_nodes.empty()) {
      new_seg_blocks.emplace_back(SegmentedBlock::kTensorRT, tensorrt_nodes);
    } else {

@@ -118,7 +128,20 @@ std::vector<SegmentedBlock> injectNodesForNonTensorInputs(SegmentedBlock& seg_bl
 }

 void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<torch::jit::Graph> g) {
-  // for NonTensor inputs in TensorRT segments, count the usages on Torch segments and TensorRT segments
+  // Create a list so we can insert SegmentedBlocks without invalidating the iterators.
+  std::list<SegmentedBlock> segmented_blocks_list(segmented_blocks.begin(), segmented_blocks.end());
+  std::unordered_map<size_t, std::list<SegmentedBlock>::iterator> idx_to_iter;
+  auto iter = segmented_blocks_list.begin();
+  for (size_t i = 0; i < segmented_blocks.size(); ++i, ++iter) {
+    idx_to_iter[i] = iter;
+  }
+
+  // usage_counts maps each non-tensor input to the indices of the segmented blocks that use it. Iterate through
+  // the segmented blocks from bottom to top; when a non-tensor input is found in block "i", record it in
+  // usage_counts. Then, for each recorded non-tensor input, check whether an earlier segmented block (index i going
+  // from n-1 down to 0) generates/contains it; if so, set that index as produce_id, since that block produces the
+  // non-tensor input.
   std::unordered_map<torch::jit::Value*, usage_info> usage_counts;
   for (int i = segmented_blocks.size() - 1; i >= 0; --i) {
     for (auto input : segmented_blocks[i].raw_inputs()) {

@@ -127,36 +150,44 @@ void resolveNonTensorInputs(PartitionedGraph& segmented_blocks, std::shared_ptr<
           : usage_counts[input].tensorrt_use_id.push_back(i);
      }
    }
+
    for (auto& use : usage_counts) {
+      // Set produce_id to the index of the segmented block that contains/produces this non-tensor torch::jit::Value.
      if (segmented_blocks[i].contain_raw_value(use.first)) {
        use.second.produce_id = i;
      }
    }
  }
+
  std::unordered_set<int> updated_segments;
  for (auto& use : usage_counts) {
    auto use_info = use.second;
    // if the segment that produces this nonTensor value is kTensorRT but it is consumed in kTorch, inject nodes into
-    // the first kTorch segments
+    // the first kTorch segment.
    if (segmented_blocks[use_info.produce_id].target() == SegmentedBlock::kTensorRT && !use_info.torch_use_id.empty()) {
-      int first_torch_id = use_info.torch_use_id.front();
+      auto first_torch_id = use_info.torch_use_id.front();
      if (!updated_segments.count(first_torch_id)) {
-        auto new_torch_block = injectNodesForNonTensorInputs(segmented_blocks[first_torch_id]).front();
-        segmented_blocks[first_torch_id] = new_torch_block;
+        // Segmented blocks with non-tensor inputs have to be re-segmented, since
+        // TRTorch doesn't support non-tensor inputs to a module.
+        auto new_torch_block = segmentBlocksWithNonTensorInputs(segmented_blocks[first_torch_id]).front();
+        *idx_to_iter[first_torch_id] = new_torch_block;
        updated_segments.insert(first_torch_id);
      }
-    } else {
-      // KTensorRT segments always need to inject nodes for the nonTensor inputs
-      for (int i : use_info.tensorrt_use_id) {
-        if (!updated_segments.count(i)) {
-          auto to_inject_blocks = injectNodesForNonTensorInputs(segmented_blocks[i]);
-          segmented_blocks.erase(segmented_blocks.begin() + i);
-          segmented_blocks.insert(segmented_blocks.begin() + i, to_inject_blocks.begin(), to_inject_blocks.end());
-          updated_segments.insert(i);
-        }
+    }
+    // kTensorRT segments always need to inject nodes for their nonTensor inputs.
+    for (auto i : use_info.tensorrt_use_id) {
+      if (!updated_segments.count(i)) {
+        // Segmented blocks with non-tensor inputs have to be re-segmented, since
+        // TRTorch doesn't support non-tensor inputs to a module.
+        auto to_inject_blocks = segmentBlocksWithNonTensorInputs(segmented_blocks[i]);
+        auto next_iter = segmented_blocks_list.erase(idx_to_iter[i]);
+        segmented_blocks_list.insert(next_iter, to_inject_blocks.begin(), to_inject_blocks.end());
+        updated_segments.insert(i);
      }
    }
  }
+  segmented_blocks.clear();
+  segmented_blocks.insert(segmented_blocks.begin(), segmented_blocks_list.begin(), segmented_blocks_list.end());
  return;
 }
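
The move from mutating the std::vector in place to a std::list plus the idx_to_iter map is the key fix here: std::vector::erase/insert invalidates iterators and shifts indices at and after the modification point, whereas std::list::erase/insert leaves iterators to all other elements valid, so earlier blocks can still be addressed while re-segmented blocks are spliced in. A minimal sketch of that property (block ids are illustrative):

#include <cassert>
#include <iterator>
#include <list>
#include <vector>

int main() {
  std::list<int> blocks{0, 1, 2, 3};
  auto it0 = blocks.begin();               // iterator to block 0
  auto it2 = std::next(blocks.begin(), 2); // iterator to block 2

  // Replace block 2 with two re-segmented blocks, mirroring
  // segmented_blocks_list.erase(...) followed by insert(...).
  auto next_iter = blocks.erase(it2);
  blocks.insert(next_iter, {20, 21});

  assert(*it0 == 0); // it0 survives the erase/insert untouched
  assert((std::vector<int>(blocks.begin(), blocks.end()) ==
          std::vector<int>{0, 1, 20, 21, 3}));
  return 0;
}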
