
Commit cbe04cb

Merge branch 'master' into dyn_shapes

2 parents: 6d0b0f6 + e3b9929

128 files changed: +786 / -146 lines changed


core/conversion/conversionctx/BUILD

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ cc_library(
     deps = [
         "@tensorrt//:nvinfer",
         "//core/util:prelude",
+        "//core/ir",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],

core/conversion/conversionctx/ConversionCtx.h

Lines changed: 2 additions & 9 deletions
@@ -9,28 +9,21 @@
 #include "torch/csrc/jit/ir/ir.h"

 #include <cuda_runtime.h>
+#include "core/ir/ir.h"
 #include "core/util/prelude.h"

 namespace torch_tensorrt {
 namespace core {
 namespace conversion {

-struct Device {
-  nvinfer1::DeviceType device_type;
-  int64_t gpu_id;
-  int64_t dla_core;
-  bool allow_gpu_fallback;
-  Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
-};
-
 struct BuilderSettings {
   std::set<nvinfer1::DataType> enabled_precisions = {};
   bool sparse_weights = false;
   bool disable_tf32 = false;
   bool refit = false;
   bool debug = false;
   bool truncate_long_and_double = false;
-  Device device;
+  ir::Device device;
   nvinfer1::EngineCapability capability = TRT_ENGINE_CAPABILITY_STANDARD;
   nvinfer1::IInt8Calibrator* calibrator = nullptr;
   uint64_t num_avg_timing_iters = 1;

core/conversion/converters/impl/element_wise.cpp

Lines changed: 13 additions & 22 deletions
@@ -166,11 +166,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          // Should implement self - alpha * other
          auto self = args[0].ITensorOrFreeze(ctx);
-         auto scalar = args[2].unwrapToScalar().to<float>();
          auto other = args[1].ITensorOrFreeze(ctx);
+         auto scalar = args[2].unwrapToScalar();

-         if (1 != scalar) {
-           auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+         if (1 != scalar.to<float>()) {
+           auto alphaTensor = scalar_to_tensor(ctx, scalar);
            auto scaleLayer = add_elementwise(
                ctx,
                nvinfer1::ElementWiseOperation::kPROD,
@@ -214,11 +214,11 @@ auto element_wise_registrations TORCHTRT_UNUSED =
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          // Should implement self - alpha * other
          auto self = args[0].ITensorOrFreeze(ctx);
-         auto scalar = args[2].unwrapToScalar().to<float>();
          auto other = args[1].ITensorOrFreeze(ctx);
+         auto scalar = args[2].unwrapToScalar();

-         if (1 != scalar) {
-           auto alphaTensor = tensor_to_const(ctx, torch::tensor({scalar}));
+         if (1 != scalar.to<float>()) {
+           auto alphaTensor = scalar_to_tensor(ctx, scalar);
            auto scaleLayer = add_elementwise(
                ctx,
                nvinfer1::ElementWiseOperation::kPROD,
@@ -351,8 +351,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
       {"aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          auto self = args[0].ITensorOrFreeze(ctx);
-         auto otherScalar = args[1].unwrapToScalar().to<float>();
-         auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+         auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
          auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
          TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);

@@ -381,8 +380,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
       {"aten::div_.Scalar(Tensor(a!) self, Scalar other) -> Tensor(a!)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          auto self = args[0].ITensorOrFreeze(ctx);
-         auto otherScalar = args[1].unwrapToScalar().to<float>();
-         auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+         auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
          auto div = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
          TORCHTRT_CHECK(div, "Unable to create div layer from node: " << *n);

@@ -481,18 +479,12 @@ auto element_wise_registrations TORCHTRT_UNUSED =
       {"aten::ne.Scalar(Tensor self, Scalar other) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          auto self = args[0].ITensorOrFreeze(ctx);
-         auto scalar = args[1].unwrapToScalar();
-         nvinfer1::ITensor* scalar_tensor;
-         if (self->getType() == nvinfer1::DataType::kFLOAT || self->getType() == nvinfer1::DataType::kHALF) {
-           scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<float>()}));
-         } else {
-           scalar_tensor = tensor_to_const(ctx, torch::tensor({scalar.to<int>()}));
-         }
+         auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
          auto equal = add_elementwise(
              ctx,
              nvinfer1::ElementWiseOperation::kEQUAL,
              self,
-             scalar_tensor,
+             other,
              util::node_info(n) + std::string("is_equal"));
          TORCHTRT_CHECK(equal, "Unable to create elementwise equal layer from node: " << *n);
          // XOR with ones negates and produces not_equal result
@@ -534,8 +526,7 @@ auto element_wise_registrations TORCHTRT_UNUSED =
       {"aten::pow.Tensor_Scalar(Tensor self, Scalar exponent) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          auto self = args[0].ITensorOrFreeze(ctx);
-         auto exponentScalar = args[1].unwrapToScalar().to<float>();
-         auto exponent = tensor_to_const(ctx, torch::tensor({exponentScalar}));
+         auto exponent = scalar_to_tensor(ctx, args[1].unwrapToScalar());
          auto pow =
              add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPOW, self, exponent, util::node_info(n));
          TORCHTRT_CHECK(pow, "Unable to create Power layer from node: " << *n);
@@ -681,9 +672,9 @@ auto element_wise_registrations TORCHTRT_UNUSED =
       {"aten::eq.Scalar(Tensor self, Scalar other) -> (Tensor)",
        [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
          auto self = args[0].ITensorOrFreeze(ctx);
-         auto otherScalar = args[1].unwrapToScalar().to<float>();
-         auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+         auto other = scalar_to_tensor(ctx, args[1].unwrapToScalar());
          if (self->getType() == nvinfer1::DataType::kBOOL) {
+           auto otherScalar = args[1].unwrapToScalar().to<float>();
            if (otherScalar == 0 || otherScalar == 1) {
              LOG_DEBUG("Since input tensor is type bool, casting input tensor and scalar to int32");
              other = castITensor(ctx, other, nvinfer1::DataType::kINT32);
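Note: the converters above now route scalar constants through a shared scalar_to_tensor helper instead of repeating tensor_to_const(ctx, torch::tensor({scalar})). The helper's definition is not visible in the files shown here; a minimal sketch of what it could look like, assuming it simply dispatches on the scalar's own type, is:

// Hypothetical sketch only -- the real helper lives in the converter
// utilities and is not part of the hunks shown above. Integral scalars
// become int constants; everything else becomes a float constant.
nvinfer1::ITensor* scalar_to_tensor(ConversionCtx* ctx, const at::Scalar& s) {
  if (s.isIntegral(/*includeBool=*/false)) {
    return tensor_to_const(ctx, torch::tensor({s.to<int>()}));
  }
  return tensor_to_const(ctx, torch::tensor({s.to<float>()}));
}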

core/ir/ir.h

Lines changed: 8 additions & 0 deletions
@@ -17,6 +17,14 @@ enum class ShapeMode {
   kMAX,
 };

+struct Device {
+  nvinfer1::DeviceType device_type;
+  int64_t gpu_id;
+  int64_t dla_core;
+  bool allow_gpu_fallback;
+  Device() : device_type(nvinfer1::DeviceType::kGPU), gpu_id(0), dla_core(0), allow_gpu_fallback(false) {}
+};
+
 struct Input : torch::CustomClassHolder {
   Input(){};
   Input(

core/lowering/BUILD

Lines changed: 1 addition & 0 deletions
@@ -24,6 +24,7 @@ cc_library(
     deps = [
         "//core/lowering/passes",
         "//core/util:prelude",
+        "//core/ir",
     ] + select({
         ":use_pre_cxx11_abi": ["@libtorch_pre_cxx11_abi//:libtorch"],
         "//conditions:default": ["@libtorch//:libtorch"],

core/lowering/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
@@ -15,6 +15,8 @@ set(HEADER_FILES
 target_sources(${lib_name}
     PRIVATE
         ${CXX_SRCS}
+    PUBLIC
+        $<TARGET_OBJECTS:core_ir>
         $<TARGET_OBJECTS:core_util>
 )

@@ -25,8 +27,9 @@ target_include_directories(${lib_name}

 target_link_libraries(${lib_name}
     PUBLIC
+        TensorRT::nvinfer
         torch
-    PRIVATE
+        core_ir
         core_util
 )

core/lowering/lowering.cpp

Lines changed: 7 additions & 2 deletions
@@ -26,7 +26,7 @@ void LowerBlock(torch::jit::Block* b) {
   DropUnusedNodes(b);
 }

-void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
+void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params, LowerInfo lower_info) {
   torch::jit::EliminateRedundantGuards(g);
   torch::jit::RemoveListMutation(g);
   torch::jit::RemoveTensorMutation(g);
@@ -70,6 +70,11 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
   passes::SiluToSigmoidMultipication(g);
   passes::RemoveSingleUse0DTensors(g);
   passes::RemoveUnnecessaryCasts(g);
+  passes::UnpackAndCastMaskedFill(g, lower_info.getGPUDeviceString());
+  passes::UnpackAndCastNumToTensor(g, lower_info.getGPUDeviceString());
+  passes::UnpackAndCastFull(g, lower_info.getGPUDeviceString());
+  passes::ReplaceScalarImplicit(g);
+  passes::RewriteInputsWithParams(g, params);
   LOG_GRAPH(*g);
 }

@@ -103,7 +108,7 @@ std::pair<std::shared_ptr<torch::jit::Graph>, std::vector<torch::jit::IValue>> L
   // In quantization aware trained (QAT) models, weights are passed through quantize and
   // dequantize nodes which should not be folded. So unfreeze_module is set to True for QAT models.
   LOG_GRAPH("Torch-TensorRT.TorchScript Graph Lowering");
-  lowering::LowerGraph(graph_and_ivalues.first, lower_info);
+  lowering::LowerGraph(graph_and_ivalues.first, graph_and_ivalues.second, lower_info);

   // Is this necessary?
   // lowering::LowerBlock(g->block());

core/lowering/lowering.h

Lines changed: 6 additions & 0 deletions
@@ -1,5 +1,6 @@
 #pragma once
 #include <memory>
+#include "core/ir/ir.h"
 #include "torch/csrc/jit/ir/ir.h"

 namespace torch_tensorrt {
@@ -15,8 +16,13 @@ struct LowerInfo {
   // Since these QDQ nodes will be identical as they share same input, one of them is eliminated due to CSE lowering
   // pass. Disable this in order to not disturb TensorRT's QAT optimizations.
   bool disable_cse = false;
+  ir::Device target_device;
   std::vector<std::string> forced_fallback_modules;
   friend std::ostream& operator<<(std::ostream& os, const LowerInfo& l);
+
+  std::string getGPUDeviceString() {
+    return "cuda:" + std::to_string(target_device.gpu_id);
+  };
 };

 void LowerBlock(torch::jit::Block* b);
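For illustration only (not part of this diff): with the default ir::Device, getGPUDeviceString() produces the device string that the new device-casting passes splice into their replacement IR patterns. Assuming g is the std::shared_ptr<torch::jit::Graph> being lowered:

LowerInfo lower_info;                        // target_device defaults to gpu_id = 0
auto dev = lower_info.getGPUDeviceString();  // yields "cuda:0"
passes::UnpackAndCastNumToTensor(g, dev);    // "cuda:0" becomes the Device constant in the rewrite (see device_casting.cpp below)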

core/lowering/passes/BUILD

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@ cc_library(
     name = "passes",
     srcs = [
         "convNd_to_convolution.cpp",
+        "device_casting.cpp",
         "exception_elimination.cpp",
         "fuse_addmm_branches.cpp",
         "linear_to_addmm.cpp",
@@ -27,6 +28,7 @@ cc_library(
         "remove_dropout.cpp",
         "remove_nops.cpp",
         "remove_unnecessary_casts.cpp",
+        "rewrite_inputs_with_params.cpp",
         "silu_to_sigmoid_multiplication.cpp",
         "unpack_addmm.cpp",
         "unpack_batch_norm.cpp",

core/lowering/passes/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
@@ -1,5 +1,6 @@
 target_sources(${lib_name}
     PRIVATE "${CMAKE_CURRENT_SOURCE_DIR}/convNd_to_convolution.cpp"
+            "${CMAKE_CURRENT_SOURCE_DIR}/device_casting.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/exception_elimination.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/fuse_addmm_branches.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/linear_to_addmm.cpp"
@@ -24,6 +25,7 @@ target_sources(${lib_name}
             "${CMAKE_CURRENT_SOURCE_DIR}/unpack_std.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/unpack_var.cpp"
             "${CMAKE_CURRENT_SOURCE_DIR}/view_to_reshape.cpp"
+            "${CMAKE_CURRENT_SOURCE_DIR}/rewrite_inputs_with_params.cpp"
 )

 set(HEADER_FILES
core/lowering/passes/device_casting.cpp

Lines changed: 121 additions & 0 deletions
@@ -0,0 +1,121 @@
+#include "torch/csrc/jit/ir/constants.h"
+#include "torch/csrc/jit/passes/subgraph_rewrite.h"
+
+#include "core/util/prelude.h"
+
+namespace torch_tensorrt {
+namespace core {
+namespace lowering {
+namespace passes {
+
+void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
+  std::string masked_fill_pattern = R"IR(
+    graph(%self, %mask, %value):
+      %out: Tensor = aten::masked_fill_(%self, %mask, %value)
+      return (%out))IR";
+
+  // Calls to masked_fill_ often utilize CPU tensors, and as such
+  // should be moved to gpu to avoid device mismatch errors
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
+    graph(%self, %mask, %value):
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
+      %dtype: NoneType = prim::Constant()
+      %false: bool = prim::Constant[value=0]()
+      %mask_cuda: Tensor = aten::to(%mask, %device, %dtype, %false, %false)
+      %self_cuda: Tensor = aten::to(%self, %device, %dtype, %false, %false)
+      %out: Tensor = aten::masked_fill(%self_cuda, %mask_cuda, %value)
+      return (%out))IR";
+
+  auto unpacked_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
+  torch::jit::SubgraphRewriter masked_fill_rewriter;
+  masked_fill_rewriter.RegisterRewritePattern(masked_fill_pattern, unpacked_pattern);
+  masked_fill_rewriter.runOnGraph(graph);
+  LOG_GRAPH("After unpack and cast masked_fill_: " << *graph);
+}
+
+void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
+  std::string num_to_tensor_cast_pattern = R"IR(
+    graph(%1: Scalar):
+      %2: Tensor = prim::NumToTensor(%1)
+      return (%2))IR";
+
+  // 0D Tensors are initialized on cpu, and need to be moved to gpu
+  // to avoid device mismatch issues
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
+    graph(%1: Scalar):
+      %2: Tensor = prim::NumToTensor(%1)
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
+      %dtype: NoneType = prim::Constant()
+      %false: bool = prim::Constant[value=0]()
+      %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
+      return (%3))IR";
+
+  auto num_to_tensor_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
+  torch::jit::SubgraphRewriter num_to_tensor_cast_rewriter;
+  num_to_tensor_cast_rewriter.RegisterRewritePattern(num_to_tensor_cast_pattern, num_to_tensor_clean_pattern);
+  num_to_tensor_cast_rewriter.runOnGraph(graph);
+
+  LOG_GRAPH("After unpack and cast NumToTensor: " << *graph);
+}
+
+void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name) {
+  std::string full_cast_pattern = R"IR(
+    graph(%1, %2, %3, %4, %5, %6):
+      %out: Tensor = aten::full(%1, %2, %3, %4, %5, %6)
+      return (%out))IR";
+
+  // Tensors created via aten::full are initialized on cpu, and need to be casted to gpu
+  // to avoid device mismatch issues
+
+  // Separate string into portions to insert device name
+  std::string clean_pattern_part_1 = R"IR(
+    graph(%1, %2, %3, %4, %5, %6):
+      %device: Device = prim::Constant[value=")IR";
+
+  std::string clean_pattern_part_2 = R"IR("]()
+      %out: Tensor = aten::full(%1, %2, %3, %4, %device, %6)
+      return (%out))IR";
+
+  auto full_clean_pattern = clean_pattern_part_1 + target_device_name + clean_pattern_part_2;
+
+  torch::jit::SubgraphRewriter full_cast_rewriter;
+  full_cast_rewriter.RegisterRewritePattern(full_cast_pattern, full_clean_pattern);
+  full_cast_rewriter.runOnGraph(graph);
+
+  LOG_GRAPH("After unpack and cast full: " << *graph);
+}
+
+void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph) {
+  std::string scalar_implicit_cast_pattern = R"IR(
+    graph(%1: Tensor):
+      %2: Scalar = aten::ScalarImplicit(%1)
+      return (%2))IR";
+
+  // ScalarImplicit can only unpack 0D tensors, whereas Tensors operated on by
+  // TensorRT are padded to 1 dimension. aten::item() resolves this conflict
+  std::string scalar_implicit_clean_pattern = R"IR(
+    graph(%1: Tensor):
+      %2: Scalar = aten::item(%1)
+      return (%2))IR";
+
+  torch::jit::SubgraphRewriter scalar_implicit_cast_rewriter;
+  scalar_implicit_cast_rewriter.RegisterRewritePattern(scalar_implicit_cast_pattern, scalar_implicit_clean_pattern);
+  scalar_implicit_cast_rewriter.runOnGraph(graph);
+
+  LOG_GRAPH("After unpack and cast full: " << *graph);
+}
+
+} // namespace passes
+} // namespace lowering
+} // namespace core
+} // namespace torch_tensorrt
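To make the string splicing above concrete: for a target device string of "cuda:0" (what LowerInfo::getGPUDeviceString() returns when gpu_id is 0), the concatenated replacement pattern used by UnpackAndCastNumToTensor resolves to the following, shown here purely as an illustration:

// Result of clean_pattern_part_1 + "cuda:0" + clean_pattern_part_2 above.
std::string num_to_tensor_clean_pattern = R"IR(
    graph(%1: Scalar):
      %2: Tensor = prim::NumToTensor(%1)
      %device: Device = prim::Constant[value="cuda:0"]()
      %dtype: NoneType = prim::Constant()
      %false: bool = prim::Constant[value=0]()
      %3: Tensor = aten::to(%2, %device, %dtype, %false, %false)
      return (%3))IR";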

core/lowering/passes/passes.h

Lines changed: 5 additions & 0 deletions
@@ -39,7 +39,12 @@ void UnpackVar(std::shared_ptr<torch::jit::Graph>& graph);
 void AliasOperators(std::shared_ptr<torch::jit::Graph>& graph);
 void SiluToSigmoidMultipication(std::shared_ptr<torch::jit::Graph>& graph);
 void UnpackHardSwish(std::shared_ptr<torch::jit::Graph>& graph);
+void RewriteInputsWithParams(std::shared_ptr<torch::jit::Graph>& g, std::vector<torch::jit::IValue>& params);
 void UnpackHardSigmoid(std::shared_ptr<torch::jit::Graph>& graph);
+void UnpackAndCastMaskedFill(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
+void UnpackAndCastNumToTensor(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
+void UnpackAndCastFull(std::shared_ptr<torch::jit::Graph>& graph, std::string target_device_name);
+void ReplaceScalarImplicit(std::shared_ptr<torch::jit::Graph>& graph);

 } // namespace passes
 } // namespace lowering
