Skip to content

Commit 43a53ce

Browse files
committed
Merge branch 'master' into dict_construct
2 parents d7d1511 + 8ec296c commit 43a53ce

28 files changed

+1023
-189
lines changed

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,8 @@ These are the following dependencies used to verify the testcases. Torch-TensorR
113113
- Bazel 4.2.1
114114
- Libtorch 1.10.0 (built with CUDA 11.3)
115115
- CUDA 11.3 (10.2 on Jetson)
116-
- cuDNN 8.2
117-
- TensorRT 8.0.3.4 (TensorRT 8.0.1.6 on Jetson)
116+
- cuDNN 8.2.1
117+
- TensorRT 8.2.4.2 (TensorRT 8.2.1 on Jetson)
118118

119119
## Prebuilt Binaries and Wheel files
120120

WORKSPACE

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,10 @@ http_archive(
8686
http_archive(
8787
name = "tensorrt",
8888
build_file = "@//third_party/tensorrt/archive:BUILD",
89-
sha256 = "da130296ac6636437ff8465812eb55dbab0621747d82dc4fe9b9376f00d214af",
90-
strip_prefix = "TensorRT-8.2.2.1",
89+
sha256 = "826180eaaecdf9a7e76116855b9f1f3400ea9b06e66b06a3f6a0747ba6f863ad",
90+
strip_prefix = "TensorRT-8.2.4.2",
9191
urls = [
92-
"https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.2.2.1/tars/tensorrt-8.2.2.1.linux.x86_64-gnu.cuda-11.4.cudnn8.2.tar.gz",
92+
"https://developer.nvidia.com/compute/machine-learning/tensorrt/secure/8.2.4/tars/tensorrt-8.2.4.2.linux.x86_64-gnu.cuda-11.4.cudnn8.2.tar.gz",
9393
],
9494
)
9595

core/compiler.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -344,9 +344,10 @@ void MapInputsAndDetermineDTypes(
344344
ss << "- Disable partial compilation by setting require_full_compilation to True";
345345
auto warn_str = ss.str();
346346
LOG_WARNING(warn_str);
347-
// Overwrite type map with user settings
348-
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
349347
}
348+
// Overwrite type map with user settings
349+
// We use this map for partitiioning since we need c10::ScalarTypes not nvinfer::DataTypes
350+
first_use_type_map[in] = {util::TRTDataTypeToScalarType(cfg.convert_info.inputs.find(in)->second.dtype)};
350351
}
351352
} else {
352353
// The user defined the type so no changes are necessary
@@ -417,18 +418,16 @@ torch::jit::Module CompileGraph(const torch::jit::Module& mod, CompileSpec cfg)
417418
auto first_use_types = ir::get_block_first_calc_dtypes_opt(g->block());
418419

419420
MapInputsAndDetermineDTypes(cfg, g, static_params, first_use_types);
420-
421+
auto isBlockConvertible = conversion::VerifyConverterSupportForBlock(g->block(), true);
421422
if (cfg.partition_info.enabled &&
422423
(cfg.lower_info.forced_fallback_modules.size() == 0 &&
423-
cfg.partition_info.forced_fallback_operators.size() == 0 &&
424-
conversion::VerifyConverterSupportForBlock(g->block(), true))) {
424+
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
425425
LOG_INFO("Skipping partitioning since model is fully supported");
426426
}
427427

428428
if (cfg.partition_info.enabled &&
429429
!(cfg.lower_info.forced_fallback_modules.size() == 0 &&
430-
cfg.partition_info.forced_fallback_operators.size() == 0 &&
431-
conversion::VerifyConverterSupportForBlock(g->block(), true))) {
430+
cfg.partition_info.forced_fallback_operators.size() == 0 && isBlockConvertible)) {
432431
auto input_ivalues_map = partitioning::generateRandomInputs(cfg.convert_info.inputs, first_use_types);
433432
auto graph_and_mapping = ConstructFallbackGraph(new_mod, g->block(), input_ivalues_map, cfg, static_params);
434433
new_g = graph_and_mapping.first;

core/conversion/conversion.cpp

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
105105
// Node input has not been converted yet or is a prim op
106106
TORCHTRT_THROW_ERROR(
107107
"Unable to retrieve all node inputs for node: "
108-
<< util::node_info(n) << " (ctx.AddLayer)\nSpecifically failed to retrieve value for input: " << *input_node);
108+
<< util::node_info(n) << " (ctx.AddLayer)\nSpecifically failed to retrieve value for input: %"
109+
<< input->debugName());
109110
}
110111
}
111112

@@ -533,18 +534,22 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_er
533534
if (unsupported_ops.size() != 0) {
534535
std::stringstream unsupported_msg;
535536
unsupported_msg
536-
<< "Method requested cannot be compiled by Torch-TensorRT.TorchScript.\nUnsupported operators listed below:"
537+
<< "Method requested cannot be compiled end to end by Torch-TensorRT.TorchScript.\nUnsupported operators listed below:"
537538
<< std::endl;
538539
for (auto s : unsupported_ops) {
539540
unsupported_msg << " - " << s.second << std::endl;
540541
}
541-
unsupported_msg << "You can either implement converters for these ops in your application or request implementation"
542-
<< std::endl;
543-
unsupported_msg << "https://www.github.com/nvidia/Torch-TensorRT/issues" << std::endl;
544-
unsupported_msg << std::endl << "In Module:" << std::endl;
545542

546543
if (!suppress_errors) {
544+
unsupported_msg
545+
<< "You can either implement converters for these ops in your application or request implementation"
546+
<< std::endl;
547+
unsupported_msg << "https://www.github.com/nvidia/Torch-TensorRT/issues" << std::endl;
548+
unsupported_msg << std::endl << "In Module:" << std::endl;
549+
547550
LOG_ERROR(unsupported_msg.str());
551+
} else {
552+
LOG_INFO(unsupported_msg.str());
548553
}
549554

550555
std::unordered_map<std::string, std::unordered_set<std::string>> unsupported_node_locations;
@@ -570,8 +575,13 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_er
570575
for (const auto& str : type.second) {
571576
traceback << str;
572577
}
578+
573579
auto tb_str = traceback.str();
574-
LOG_ERROR(tb_str);
580+
if (!suppress_errors) {
581+
LOG_ERROR(tb_str);
582+
} else {
583+
LOG_DEBUG(tb_str);
584+
}
575585
}
576586

577587
return false;

core/conversion/converters/impl/batch_norm.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -50,27 +50,35 @@ auto batch_norm_registrations TORCHTRT_UNUSED =
5050
auto orig_shape = input->getDimensions();
5151
auto shape = util::toVec(orig_shape);
5252
auto tensor_type = util::TRTDataTypeToScalarType(input->getType());
53-
auto options = torch::TensorOptions().dtype(tensor_type);
53+
auto options =
54+
torch::TensorOptions().dtype(tensor_type).device(torch::kCUDA, ctx->settings.device.gpu_id);
5455

5556
torch::Tensor gamma, beta, mean, var;
57+
LOG_DEBUG("Input :" << orig_shape << "/" << input->getType());
58+
// affine=True
59+
LOG_DEBUG("Args[1] gamma : " << args[1].isIValue() << " / " << args[1].IValue()->isNone());
60+
LOG_DEBUG("Args[2] beta : " << args[2].isIValue() << " / " << args[2].IValue()->isNone());
61+
// track_running_stats=True
62+
LOG_DEBUG("Args[3] mean : " << args[3].isIValue() << " / " << args[3].IValue()->isNone());
63+
LOG_DEBUG("Args[4] var : " << args[4].isIValue() << " / " << args[4].IValue()->isNone());
64+
LOG_DEBUG("use_input_stats, momemtum, cudnn_enabled disregarded");
65+
LOG_DEBUG("ctx->input_is_dynamic : " << ctx->input_is_dynamic);
5666

67+
auto channel_dim = shape[1];
5768
if (ctx->input_is_dynamic) {
58-
gamma = args[1].unwrapToTensor();
59-
beta = args[2].unwrapToTensor();
69+
gamma = args[1].unwrapToTensor(at::full(channel_dim, 1, options));
70+
beta = args[2].unwrapToTensor(at::full(channel_dim, 0, options));
6071
mean = args[3].unwrapToTensor();
6172
var = args[4].unwrapToTensor();
6273
} else {
63-
gamma = args[1].unwrapToTensor(at::full({shape}, 1, {options}));
64-
beta = args[2].unwrapToTensor(at::full({shape}, 1, {options}));
65-
mean = args[3].unwrapToTensor(at::full({shape}, 0, {options}));
66-
var = args[4].unwrapToTensor(at::full({shape}, 0, {options}));
74+
gamma = args[1].unwrapToTensor(at::full(channel_dim, 1, options));
75+
beta = args[2].unwrapToTensor(at::full(channel_dim, 0, options));
76+
mean = args[3].unwrapToTensor(at::full(channel_dim, 0, options));
77+
var = args[4].unwrapToTensor(at::full(channel_dim, 0, options));
6778
}
6879

6980
auto eps = static_cast<float>(args[7].unwrapToDouble(1e-5f));
7081

71-
LOG_DEBUG("momentum disregarded");
72-
LOG_DEBUG("training disregarded");
73-
LOG_DEBUG("cudnn disregarded");
7482
TORCHTRT_CHECK(orig_shape.nbDims >= 2, "Unable to create batch normalization layer from node: " << *n);
7583

7684
// Expand spatial dims from 1D to 2D if needed

core/conversion/converters/impl/select.cpp

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,34 @@ bool add_split(ConversionCtx* ctx, const torch::jit::Node* n, args& args, bool s
6767
return true;
6868
}
6969

70+
nvinfer1::ITensor* roll(
71+
ConversionCtx* ctx,
72+
nvinfer1::ITensor* in,
73+
int shift,
74+
int dim,
75+
const std::vector<int64_t>& in_shape) {
76+
auto in_dim = in_shape[dim];
77+
78+
auto start = (in_dim - shift) % in_dim;
79+
// Behavior of % is different in C++ vs Python for negative numbers. This
80+
// corrects the difference.
81+
if (start < 0) {
82+
start = start + in_dim;
83+
}
84+
at::Tensor index0 = at::arange(start, in_dim, 1, torch::kInt32);
85+
at::Tensor index;
86+
if (start == 0) {
87+
index = index0;
88+
} else {
89+
at::Tensor index1 = at::arange(start, torch::kInt32);
90+
index = at::cat({index0, index1}, 0);
91+
}
92+
auto index_tensor = tensor_to_const(ctx, index);
93+
auto gather_layer = ctx->net->addGather(*in, *index_tensor, dim);
94+
auto out = gather_layer->getOutput(0);
95+
return out;
96+
}
97+
7098
auto select_registrations TORCHTRT_UNUSED =
7199
RegisterNodeConversionPatterns()
72100
.pattern({"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))",
@@ -200,6 +228,69 @@ auto select_registrations TORCHTRT_UNUSED =
200228

201229
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
202230

231+
return true;
232+
}})
233+
.pattern({"aten::roll(Tensor self, int[1] shifts, int[1] dims=[]) -> (Tensor)",
234+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
235+
auto in = args[0].ITensor();
236+
auto shifts = args[1].unwrapToIntList().vec();
237+
auto dims = args[2].unwrapToIntList().vec();
238+
239+
TORCHTRT_CHECK(dims.size() == shifts.size(), "dims.size() should be equal to shifts.size()");
240+
if (ctx->input_is_dynamic) {
241+
TORCHTRT_THROW_ERROR("aten::roll is currently not support in dynamic input shape compilation");
242+
} else {
243+
auto in_shape = util::toVec(in->getDimensions());
244+
for (size_t i = 0; i < dims.size(); i++) {
245+
auto dim = dims[i] < 0 ? (in_shape.size() + dims[i]) : dims[i];
246+
TORCHTRT_CHECK(dim < in_shape.size(), "Dimension out of range");
247+
in = roll(ctx, in, shifts[i], dim, in_shape);
248+
}
249+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
250+
251+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
252+
253+
return true;
254+
}
255+
}})
256+
.pattern(
257+
{"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)",
258+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
259+
auto in = args[0].ITensorOrFreeze(ctx);
260+
auto ts = args[1].IValue()->toListRef();
261+
262+
std::vector<nvinfer1::ITensor*> tensors;
263+
for (auto t : ts) {
264+
if (t.isTensor()) {
265+
auto torch_tensor = t.toTensor();
266+
tensors.push_back(tensor_to_const(ctx, torch_tensor));
267+
} else {
268+
auto cont = t.toCustomClass<TensorContainer>();
269+
tensors.push_back(cont->tensor());
270+
}
271+
}
272+
273+
// In TorchScript, aten::index.Tensor indexes the self tensor along its each dimension by several
274+
// indexes. In this version of Torch-TensorRT, it can only receive one index tensor which means it only
275+
// indexes the self tensor along dimension 0.
276+
TORCHTRT_CHECK(
277+
tensors.size() == 1,
278+
"In this version of Torch-TensorRT, aten::index.Tensor can only receive one index tensor which means it only indexes the self tensor along dimension 0.");
279+
auto indicesTensor = tensors[0];
280+
// Set datatype for indices tensor to INT32
281+
auto identity = ctx->net->addIdentity(*indicesTensor);
282+
identity->setOutputType(0, nvinfer1::DataType::kINT32);
283+
indicesTensor = identity->getOutput(0);
284+
285+
// IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices
286+
// from
287+
auto gather_layer = ctx->net->addGather(*in, *indicesTensor, 0);
288+
TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n);
289+
auto gather_out = gather_layer->getOutput(0);
290+
291+
auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out);
292+
293+
LOG_DEBUG("Output tensor shape: " << out->getDimensions());
203294
return true;
204295
}})
205296
.pattern(

core/conversion/converters/impl/stack.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,16 @@ auto stack_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().patt
1919
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
2020
auto in = args[0].IValue()->toListRef();
2121
auto dim = args[1].unwrapToInt();
22+
if (-1 == dim) {
23+
auto first_in = in[0];
24+
if (first_in.isTensor()) {
25+
dim = first_in.toTensor().ndimension();
26+
} else {
27+
dim = first_in.toCustomClass<TensorContainer>()->tensor()->getDimensions().nbDims;
28+
}
29+
}
2230

2331
std::vector<nvinfer1::ITensor*> tensors;
24-
2532
for (auto t : in) {
2633
nvinfer1::ITensor* itensor;
2734

core/conversion/evaluators/aten.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include <math.h>
2+
13
#include "ATen/core/List.h"
24
#include "ATen/core/functional.h"
35
#include "ATen/core/ivalue.h"
@@ -98,6 +100,17 @@ DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
98100
"aten::ge.float_int(float a, int b) -> (bool)",
99101
}));
100102

103+
DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(
104+
pow,
105+
"aten::pow",
106+
pow(a, b),
107+
std::set<std::string>({
108+
"aten::pow.int(int a, int b) -> (float)",
109+
"aten::pow.float(float a, float b) -> (float)",
110+
"aten::pow.int_float(int a, float b) -> (float)",
111+
"aten::pow.float_int(float a, int b) -> (float)",
112+
}));
113+
101114
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
102115
and,
103116
"aten::__and__",

core/conversion/evaluators/eval_macros.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,53 @@
7777
}, \
7878
EvalOptions().validSchemas(schemas)});
7979

80+
#define DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \
81+
auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
82+
{c10::Symbol::fromQualString(node_kind), \
83+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
84+
if (args.at(n->input(0)).IValue()->isInt()) { \
85+
auto a = args.at(n->input(0)).unwrapToInt(); \
86+
if (args.at(n->input(1)).IValue()->isInt()) { \
87+
auto b = args.at(n->input(1)).unwrapToInt(); \
88+
return operation; \
89+
} else if (args.at(n->input(1)).IValue()->isDouble()) { \
90+
auto b = args.at(n->input(1)).unwrapToDouble(); \
91+
return operation; \
92+
} else if (args.at(n->input(1)).IValue()->isBool()) { \
93+
auto b = args.at(n->input(1)).unwrapToBool(); \
94+
return operation; \
95+
} else { \
96+
TORCHTRT_THROW_ERROR( \
97+
"Unimplemented data type for " \
98+
<< node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
99+
return {}; \
100+
} \
101+
} else if (args.at(n->input(0)).IValue()->isDouble()) { \
102+
auto a = args.at(n->input(0)).unwrapToDouble(); \
103+
if (args.at(n->input(1)).IValue()->isInt()) { \
104+
auto b = args.at(n->input(1)).unwrapToInt(); \
105+
return operation; \
106+
} else if (args.at(n->input(1)).IValue()->isDouble()) { \
107+
auto b = args.at(n->input(1)).unwrapToDouble(); \
108+
return operation; \
109+
} else if (args.at(n->input(1)).IValue()->isBool()) { \
110+
auto b = args.at(n->input(1)).unwrapToBool(); \
111+
return operation; \
112+
} else { \
113+
TORCHTRT_THROW_ERROR( \
114+
"Unimplemented data type for " \
115+
<< node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
116+
return {}; \
117+
} \
118+
} else { \
119+
TORCHTRT_THROW_ERROR( \
120+
"Unimplemented data type for " \
121+
<< node_kind << " evaluator a arg: " << args.at(n->input(0)).IValue()->type()->str()); \
122+
return {}; \
123+
} \
124+
}, \
125+
EvalOptions().validSchemas(schemas)});
126+
80127
#define DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(node_kind, node_name, operation, type, schemas) \
81128
auto node_kind##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
82129
{c10::Symbol::fromQualString(node_name), \

core/lowering/lowering.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
4444
passes::EliminateExceptionOrPassPattern(g);
4545
passes::ReduceToOperation(g);
4646
passes::ReduceGelu(g);
47+
passes::ReduceRemainder(g);
4748
passes::RemoveContiguous(g);
4849
passes::ViewToReshape(g);
4950
passes::RemoveDropout(g);

0 commit comments

Comments
 (0)