Commit 17340fb

Merge branch 'master' into trt8.4
2 parents c009a1f + d6a2b88 commit 17340fb

30 files changed (+1013, -92 lines)

core/compiler.cpp

Lines changed: 2 additions & 1 deletion
@@ -328,7 +328,8 @@ void MapInputsAndDetermineDTypes(
         spec.dtype = nvinfer1::DataType::kFLOAT;
       } else if (spec.dtype_is_user_defined && cfg.partition_info.enabled) {
         if (!est_type_opt) {
-          LOG_INFO("Cannot infer input tensor dtype in graph, unable to verify user input dtype settings");
+          LOG_INFO("Cannot infer input tensor dtype in graph. Using user provided input dtype settings");
+          first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
        } else {
          if (util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype) != est_type_opt.value()) {
            std::stringstream ss;

core/conversion/conversion.cpp

Lines changed: 32 additions & 17 deletions
@@ -1,15 +1,15 @@
 #include "core/conversion/conversion.h"
+#include <ATen/core/operator_name.h>
 #include <torch/torch.h>
 #include <sstream>
+#include "c10/util/intrusive_ptr.h"
 #include "core/conversion/conversionctx/ConversionCtx.h"
+#include "core/conversion/converters/converter_util.h"
 #include "core/conversion/converters/converters.h"
 #include "core/conversion/evaluators/evaluators.h"
+#include "core/conversion/tensorcontainer/TensorContainer.h"
 #include "core/conversion/var/Var.h"
 #include "core/util/prelude.h"
-
-#include "c10/util/intrusive_ptr.h"
-#include "core/conversion/converters/converter_util.h"
-#include "core/conversion/tensorcontainer/TensorContainer.h"
 #include "core/util/trt_util.h"

 namespace torch_tensorrt {
@@ -427,10 +427,18 @@ void ConvertBlockToNetDef(
             << " and node outputs size: " << n->outputs().size() << " must match.");
         for (size_t i = 0; i < eval_list->elements().size(); i++) {
           auto eval_output = eval_list.get()->elements()[i];
-          LOG_DEBUG(
-              ctx->logger,
-              "Found the evaluated value(s) to be " << eval_output << " for node: " << util::node_info(n));
-          ctx->AssociateValueAndIValue(n->output(i), eval_output);
+          if (eval_output.isCustomClass()) {
+            auto container = eval_output.toCustomClass<TensorContainer>();
+            auto tensor = container->tensor();
+            LOG_DEBUG(
+                ctx->logger, "Found the evaluated value(s) to be an ITensor of shape: " << tensor->getDimensions());
+            ctx->AssociateValueAndTensor(n->output(i), tensor);
+          } else {
+            LOG_DEBUG(
+                ctx->logger,
+                "Found the evaluated value(s) to be " << eval_output << " for node: " << util::node_info(n));
+            ctx->AssociateValueAndIValue(n->output(i), eval_output);
+          }
         }
       } else {
         TORCHTRT_THROW_ERROR("Unsupported return type for evaluated node");
@@ -488,15 +496,23 @@ std::string ConvertBlockToEngine(
 std::unordered_map<c10::OperatorName, std::string> GetUnsupportedOpsInBlock(const torch::jit::Block* b) {
   std::unordered_map<c10::OperatorName, std::string> unsupported_ops;
   for (const auto n : b->nodes()) {
-    if (n->kind() != torch::jit::prim::Loop && n->kind() != torch::jit::prim::If && !OpSupported(n)) {
-      auto schema = n->maybeSchema();
-      TORCHTRT_CHECK(
-          schema,
-          "Unable to get schema for Node " << util::node_info(n) << " (conversion.VerifyCoverterSupportForBlock)");
-      std::stringstream ss;
-      ss << *schema;
-      unsupported_ops[schema->operator_name()] = ss.str();
+    auto schema = n->maybeSchema();
+    // Some ops like torch::jit::prim::Loop, torch::jit::prim::If, torch::jit::prim::DictConstruct don't have a schema
+    // but they are supported. torch::jit::prim::DictConstruct is supported via fallback only
+    if (!OpSupported(n)) {
+      if (schema) {
+        std::stringstream ss;
+        ss << *schema;
+        unsupported_ops[schema->operator_name()] = ss.str();
+      } else {
+        std::stringstream ss;
+        ss << util::node_info(n);
+        // operator.overload is a filler name just to call the constructor.
+        c10::OperatorName op(ss.str(), "operator.overload");
+        unsupported_ops[op] = ss.str();
+      }
     }
+
     for (const auto sub_b : n->blocks()) {
       auto sub_b_unsupported_ops = GetUnsupportedOpsInBlock(sub_b);
       unsupported_ops.insert(sub_b_unsupported_ops.begin(), sub_b_unsupported_ops.end());
@@ -531,7 +547,6 @@ std::set<std::string> ConvertableOpsInBlock(const torch::jit::Block* b) {
 
 bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_errors) {
   auto unsupported_ops = GetUnsupportedOpsInBlock(b);
-
   if (unsupported_ops.size() != 0) {
     std::stringstream unsupported_msg;
     unsupported_msg

core/conversion/converters/impl/activation.cpp

Lines changed: 21 additions & 0 deletions
@@ -87,6 +87,27 @@ auto acthardtanh TORCHTRT_UNUSED =
 
            bool to_reshape = false;
            auto original_shape = in->getDimensions();
+
+           // The output of ParametricReLU has a shape of all zeros when the slopes nDims does not
+           // equal the input nDims, so make sure the slopes nDims equals the input nDims.
+           if (slopes.ndimension() == 1 and original_shape.nbDims != slopes.ndimension()) {
+             std::vector<int64_t> slopes_new_shape(original_shape.nbDims, 1);
+             auto first_inputs_allowed_formats = ctx->net->getInput(0)->getAllowedFormats();
+             for (size_t inputs_index = 1; inputs_index < ctx->num_inputs; inputs_index++) {
+               auto inputs_allowed_formats = ctx->net->getInput(inputs_index)->getAllowedFormats();
+               TORCHTRT_CHECK(
+                   first_inputs_allowed_formats == inputs_allowed_formats,
+                   "Unable to create batch prelu layer from node, since the formats (e.g. NHWC or NCHW) of the inputs differ: "
+                       << *n);
+             }
+             if (1U << static_cast<int>(nvinfer1::TensorFormat::kLINEAR) == first_inputs_allowed_formats) {
+               slopes_new_shape[1] = slopes.sizes().vec()[0];
+             } else {
+               slopes_new_shape[original_shape.nbDims - 1] = slopes.sizes().vec()[0];
+             }
+             slopes = slopes.reshape(slopes_new_shape);
+           }
+
            if (slopes.numel() != 1 &&
                !util::broadcastable(
                    in->getDimensions(),
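
To illustrate the reshape above: for a 1-D slopes tensor of length C and a rank-4 NCHW input, the converter expands the slopes to a broadcast-compatible shape with C on the channel axis. The helper below is a hypothetical, self-contained sketch of that shape computation only; broadcast_slope_shape is not part of the codebase.

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper mirroring the converter logic: place the C slope values on the
// channel axis (dim 1 for NCHW-style linear format, the last dim otherwise) and pad
// every other dim with 1 so the tensor broadcasts against the input.
std::vector<int64_t> broadcast_slope_shape(int64_t channels, int32_t nbDims, bool channels_first) {
  std::vector<int64_t> shape(nbDims, 1);
  shape[channels_first ? 1 : nbDims - 1] = channels;
  return shape;
}

int main() {
  auto s = broadcast_slope_shape(64, 4, /*channels_first=*/true);
  assert((s == std::vector<int64_t>{1, 64, 1, 1})); // matches slopes.reshape(slopes_new_shape) above
  return 0;
}
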

core/ir/Input.cpp

Lines changed: 9 additions & 2 deletions
@@ -40,6 +40,13 @@ bool valid_dtype_format_combo(nvinfer1::DataType dtype, nvinfer1::TensorFormat f
         default:
           return false;
       }
+    case nvinfer1::DataType::kBOOL: // Supports Linear (NCHW)
+      switch (format) {
+        case nvinfer1::TensorFormat::kLINEAR:
+          return true;
+        default:
+          return false;
+      }
     default:
       return false;
   }
@@ -48,7 +55,7 @@ bool valid_dtype_format_combo(nvinfer1::DataType dtype, nvinfer1::TensorFormat f
 bool valid_input_dtype(nvinfer1::DataType dtype) {
   switch (dtype) {
     case nvinfer1::DataType::kBOOL:
-      return false;
+      return true;
     case nvinfer1::DataType::kFLOAT:
       return true;
     case nvinfer1::DataType::kHALF:
@@ -153,4 +160,4 @@ std::ostream& operator<<(std::ostream& os, const Input& input) {
 
 } // namespace ir
 } // namespace core
-} // namespace torch_tensorrt
+} // namespace torch_tensorrt
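
The net effect of the core/ir/Input.cpp change is that boolean inputs are now accepted, but only in linear (NCHW) format. A minimal sketch exercising the updated checks is below; it assumes valid_input_dtype and valid_dtype_format_combo are visible through core/ir/ir.h, which may differ from the actual headers.

#include <cassert>
#include "NvInfer.h"
#include "core/ir/ir.h" // assumption: declares valid_input_dtype / valid_dtype_format_combo

int main() {
  using namespace torch_tensorrt::core::ir;
  // kBOOL is now a valid input dtype...
  assert(valid_input_dtype(nvinfer1::DataType::kBOOL));
  // ...but only when paired with the linear (NCHW) tensor format.
  assert(valid_dtype_format_combo(nvinfer1::DataType::kBOOL, nvinfer1::TensorFormat::kLINEAR));
  assert(!valid_dtype_format_combo(nvinfer1::DataType::kBOOL, nvinfer1::TensorFormat::kCHW4));
  return 0;
}
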

core/lowering/lowering.cpp

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,7 @@
 #include "torch/csrc/jit/passes/common_subexpression_elimination.h"
 #include "torch/csrc/jit/passes/create_functional_graphs.h"
 #include "torch/csrc/jit/passes/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/erase_number_types.h"
 #include "torch/csrc/jit/passes/freeze_module.h"
 #include "torch/csrc/jit/passes/fuse_linear.h"
 #include "torch/csrc/jit/passes/guard_elimination.h"
@@ -64,6 +65,8 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::RemoveNOPs(g);
   passes::AliasOperators(g);
   passes::SiluToSigmoidMultipication(g);
+  passes::RemoveSingleUse0DTensors(g);
+  passes::RemoveUnnecessaryCasts(g);
   LOG_GRAPH(*g);
 }

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ cc_library(
         "view_to_reshape.cpp",
         "remove_dropout.cpp",
         "remove_nops.cpp",
+        "remove_unnecessary_casts.cpp",
         "silu_to_sigmoid_multiplication.cpp",
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",

core/lowering/passes/passes.h

Lines changed: 2 additions & 0 deletions
@@ -28,6 +28,8 @@ void RemoveContiguous(std::shared_ptr<torch::jit::Graph>& graph);
 void ViewToReshape(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph);
 void RemoveNOPs(std::shared_ptr<torch::jit::Graph> graph);
+void RemoveSingleUse0DTensors(std::shared_ptr<torch::jit::Graph>& g);
+void RemoveUnnecessaryCasts(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackAddMM(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackBatchNorm(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackLogSoftmax(std::shared_ptr<torch::jit::Graph>& graph);

core/lowering/passes/reduce_gelu.cpp

Lines changed: 33 additions & 1 deletion
@@ -8,10 +8,17 @@ namespace passes {
 
 void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
   std::string gelu_pattern = R"IR(
-    graph(%x):
+    graph(%x : Tensor):
         %out : Tensor = aten::gelu(%x)
         return (%out))IR";
 
+  // This gelu_approximate_pattern schema exists in the 21.11, 21.12, and 22.01 PyTorch containers. These container
+  // versions use an unmerged PR in pytorch: https://github.com/pytorch/pytorch/pull/61439. We reduce this to regular Gelu.
+  std::string gelu_approximate_pattern = R"IR(
+    graph(%x : Tensor, %approx):
+        %out : Tensor = aten::gelu(%x, %approx)
+        return (%out))IR";
+
   std::string gelu_reduce_pattern = R"IR(
     graph(%x.1 : Tensor):
         %6 : float = prim::Constant[value=0.044714999999999998]()
@@ -30,11 +37,36 @@ void ReduceGelu(std::shared_ptr<torch::jit::Graph>& graph) {
         %15 : Tensor = aten::mul(%7, %14)
         return (%15))IR";
 
+  // This is the same as gelu_reduce_pattern except for an additional input %approx.
+  // SubgraphRewriter only works as expected if the number of inputs to gelu_approximate_pattern
+  // and gelu_reduce_multi_input_pattern is the same.
+  std::string gelu_reduce_multi_input_pattern = R"IR(
+    graph(%x.1 : Tensor, %approx):
+        %6 : float = prim::Constant[value=0.044714999999999998]()
+        %5 : float = prim::Constant[value=0.79788456080000003]()
+        %4 : float = prim::Constant[value=1.]()
+        %3 : float = prim::Constant[value=0.5]()
+        %2 : int = prim::Constant[value=1]()
+        %7 : Tensor = aten::mul(%x.1, %3)
+        %8 : Tensor = aten::mul(%x.1, %5)
+        %9 : Tensor = aten::mul(%x.1, %6)
+        %10 : Tensor = aten::mul(%9, %x.1)
+        %11 : Tensor = aten::add(%10, %4, %2)
+        %12 : Tensor = aten::mul(%8, %11)
+        %13 : Tensor = aten::tanh(%12)
+        %14 : Tensor = aten::add(%13, %4, %2)
+        %15 : Tensor = aten::mul(%7, %14)
+        return (%15))IR";
+
   // replace aten::gelu with pointwise operations
   torch::jit::SubgraphRewriter map_gelu_to_pointwise_ops;
   map_gelu_to_pointwise_ops.RegisterRewritePattern(gelu_pattern, gelu_reduce_pattern);
   map_gelu_to_pointwise_ops.runOnGraph(graph);
 
+  torch::jit::SubgraphRewriter map_gelu_approximate_to_pointwise_ops;
+  map_gelu_approximate_to_pointwise_ops.RegisterRewritePattern(gelu_approximate_pattern, gelu_reduce_multi_input_pattern);
+  map_gelu_approximate_to_pointwise_ops.runOnGraph(graph);
+
   LOG_GRAPH("Post lowering of [aten::gelu] -> " << *graph);
 }
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#include <stack>
2+
#include <unordered_set>
3+
4+
#include "torch/csrc/jit/passes/subgraph_rewrite.h"
5+
6+
#include "core/lowering/passes/passes.h"
7+
#include "core/util/prelude.h"
8+
9+
namespace torch_tensorrt {
10+
namespace core {
11+
namespace lowering {
12+
namespace passes {
13+
14+
void RemoveSetAttrs(const torch::jit::Module& mod, std::string method_name) {
15+
auto g = mod.get_method(method_name).graph();
16+
17+
std::string set_attr_pattern = R"IR(
18+
graph(%self, %0):
19+
None = prim::SetAttr[name="_has_warned"](%self, %0)
20+
return ())IR";
21+
std::string no_set_attr_pattern = R"IR(
22+
graph(%self, %0):
23+
return ())IR";
24+
25+
// remove contiguous
26+
torch::jit::SubgraphRewriter remove_set_attr;
27+
remove_set_attr.RegisterRewritePattern(set_attr_pattern, no_set_attr_pattern);
28+
remove_set_attr.runOnGraph(g);
29+
LOG_GRAPH("Post remove contiguous: " << *g);
30+
}
31+
32+
} // namespace passes
33+
} // namespace lowering
34+
} // namespace core
35+
} // namespace torch_tensorrt
