Skip to content

Commit e2b7b56

Browse files
committed
Merge branch 'guoruoqian-fix_conv1d' into 'release/1.0'
feat: support aten::conv1d and aten::conv_transpose1d See merge request adlsa/TRTorch!8
2 parents 054373b + 55f90e2 commit e2b7b56

File tree

8 files changed

+178
-2
lines changed

8 files changed

+178
-2
lines changed

core/conversion/converters/impl/conv_deconv.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,4 +202,4 @@ auto conv_registrations TORCHTRT_UNUSED =
202202
} // namespace converters
203203
} // namespace conversion
204204
} // namespace core
205-
} // namespace torch_tensorrt
205+
} // namespace torch_tensorrt

core/lowering/lowering.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ void LowerGraph(std::shared_ptr<torch::jit::Graph>& g, LowerInfo lower_info) {
4646
passes::RemoveContiguous(g);
4747
passes::RemoveDropout(g);
4848
passes::LinearToAddMM(g);
49+
passes::Conv1DToConvolution(g);
50+
passes::ConvTransposed1DToConvolution(g);
4951
passes::Conv2DToConvolution(g);
5052
passes::Conv3DToConvolution(g);
5153
passes::FuseAddMMBranches(g);

core/lowering/passes/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ config_setting(
1010
cc_library(
1111
name = "passes",
1212
srcs = [
13+
"conv1d_to_convolution.cpp",
1314
"conv2d_to_convolution.cpp",
1415
"conv3d_to_convolution.cpp",
1516
"exception_elimination.cpp",
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
2+
3+
#include "core/util/prelude.h"
4+
5+
namespace trtorch {
6+
namespace core {
7+
namespace lowering {
8+
namespace passes {
9+
10+
void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
11+
std::string conv1d_pattern = R"IR(
12+
graph(%x, %w, %b, %s, %p, %d, %g):
13+
%4 : Tensor = aten::conv1d(%x, %w, %b, %s, %p, %d, %g)
14+
return (%4))IR";
15+
std::string convolution_pattern = R"IR(
16+
graph(%x, %w, %b, %s, %p, %d, %g):
17+
%1 : bool = prim::Constant[value=0]()
18+
%2 : int[] = prim::Constant[value=[0]]()
19+
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %2, %g, %1, %1, %1, %1)
20+
return (%4))IR";
21+
22+
torch::jit::SubgraphRewriter map_conv1d_to_convolution;
23+
map_conv1d_to_convolution.RegisterRewritePattern(conv1d_pattern, convolution_pattern);
24+
map_conv1d_to_convolution.runOnGraph(graph);
25+
LOG_GRAPH("Post map conv1d -> _convolution: " << *graph);
26+
}
27+
28+
void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph) {
29+
std::string conv_transpose1d_pattern = R"IR(
30+
graph(%x, %w, %b, %s, %p, %o, %g, %d):
31+
%4 : Tensor = aten::conv_transpose1d(%x, %w, %b, %s, %p, %o, %g, %d)
32+
return (%4))IR";
33+
std::string convolution_pattern = R"IR(
34+
graph(%x, %w, %b, %s, %p, %o, %g, %d):
35+
%1 : bool = prim::Constant[value=1]()
36+
%2 : bool = prim::Constant[value=1]()
37+
%4 : Tensor = aten::_convolution(%x, %w, %b, %s, %p, %d, %1, %o, %g, %2, %2, %2, %2)
38+
return (%4))IR";
39+
40+
torch::jit::SubgraphRewriter map_conv_transpose1d_to_convolution;
41+
map_conv_transpose1d_to_convolution.RegisterRewritePattern(conv_transpose1d_pattern, convolution_pattern);
42+
map_conv_transpose1d_to_convolution.runOnGraph(graph);
43+
LOG_GRAPH("Post map conv_transpose1d -> _convolution: " << *graph);
44+
}
45+
46+
} // namespace passes
47+
} // namespace lowering
48+
} // namespace core
49+
} // namespace trtorch

core/lowering/passes/passes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@ void NotateModuleForFallback(
1212
std::string mod_name,
1313
std::string method_name,
1414
std::unordered_set<std::string> forced_fallback_modules);
15+
void Conv1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
16+
void ConvTransposed1DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1517
void Conv2DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1618
void Conv3DToConvolution(std::shared_ptr<torch::jit::Graph>& graph);
1719
void FuseAddMMBranches(std::shared_ptr<torch::jit::Graph> graph);

tests/core/conversion/converters/test_conv_deconv.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -654,4 +654,4 @@ TEST(Converters, ATenConvTransposeWithGroupConvertsCorrectly) {
654654
auto trt = trt_results[0].reshape(jit_results[0].sizes());
655655

656656
ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6));
657-
}
657+
}

tests/core/lowering/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ cc_test(
2626
]
2727
)
2828

29+
lowering_test(
30+
name = "test_conv1d_pass",
31+
)
32+
2933
lowering_test(
3034
name = "test_remove_contiguous_pass",
3135
)
@@ -61,6 +65,7 @@ lowering_test(
6165
test_suite(
6266
name = "lowering_tests",
6367
tests = [
68+
":test_conv1d_pass",
6469
":test_linear_to_addmm",
6570
":test_module_fallback_passes",
6671
":test_operator_aliasing_pass",
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#include <string>
2+
#include "core/compiler.h"
3+
#include "core/lowering/passes/passes.h"
4+
#include "gtest/gtest.h"
5+
#include "tests/util/util.h"
6+
#include "torch/csrc/jit/ir/irparser.h"
7+
#include "torch/csrc/jit/ir/subgraph_matcher.h"
8+
9+
TEST(LoweringPasses, Conv1dCorrectly) {
10+
const auto source_graph = R"IR(
11+
graph(%0 : Tensor,
12+
%1 : Float(4, 3, 3, strides=[9, 3, 1]),
13+
%2 : Float(3)):
14+
%4 : int = prim::Constant[value=0]()
15+
%5 : int = prim::Constant[value=1]()
16+
%6 : int = prim::Constant[value=1]()
17+
%stride : int[] = prim::ListConstruct(%6)
18+
%padding : int[] = prim::ListConstruct(%4)
19+
%dilation : int[] = prim::ListConstruct(%5)
20+
%12 : Tensor = aten::conv1d(%0, %1, %2, %stride, %padding, %dilation, %6)
21+
return (%12))IR";
22+
23+
const auto target_graph = R"IR(
24+
graph(%0 : Tensor,
25+
%1 : Float(4, 3, 3, strides=[9, 3, 1]),
26+
%2 : Float(3)):
27+
%3 : bool = prim::Constant[value=0]()
28+
%4 : int = prim::Constant[value=0]()
29+
%5 : int = prim::Constant[value=1]()
30+
%6 : int = prim::Constant[value=1]()
31+
%stride : int[] = prim::ListConstruct(%6)
32+
%padding : int[] = prim::ListConstruct(%4)
33+
%dilation : int[] = prim::ListConstruct(%5)
34+
%output_padding : int[] = prim::Constant[value=[0]]()
35+
%12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
36+
return (%12))IR";
37+
38+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
39+
auto sg = std::make_shared<torch::jit::Graph>();
40+
torch::jit::parseIR(source_graph, &*sg);
41+
trtorch::core::lowering::passes::Conv1DToConvolution(sg);
42+
43+
auto tg = std::make_shared<torch::jit::Graph>();
44+
torch::jit::parseIR(target_graph, &*tg);
45+
46+
auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
47+
auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});
48+
auto b = at::randint(1, 10, {4}, {at::kCUDA});
49+
50+
auto trt_in = at::clone(in);
51+
auto trt_w = at::clone(w);
52+
auto trt_b = at::clone(b);
53+
auto params = trtorch::core::conversion::get_named_params(sg->inputs(), {trt_w, trt_b});
54+
auto trt_results_sg = trtorch::tests::util::RunGraphEngine(sg, params, {trt_in});
55+
56+
params = trtorch::core::conversion::get_named_params(tg->inputs(), {trt_w, trt_b});
57+
auto trt_results_tg = trtorch::tests::util::RunGraphEngine(tg, params, {trt_in});
58+
59+
ASSERT_TRUE(trtorch::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
60+
}
61+
62+
TEST(LoweringPasses, ConvTransposed1dCorrectly) {
63+
const auto source_graph = R"IR(
64+
graph(%0 : Tensor,
65+
%1 : Float(8, 3, 3, strides=[9, 3, 1]),
66+
%2 : Float(3)):
67+
%3 : int = prim::Constant[value=1]()
68+
%4 : int = prim::Constant[value=0]()
69+
%5 : int = prim::Constant[value=1]()
70+
%6 : int = prim::Constant[value=0]()
71+
%stride : int[] = prim::ListConstruct(%3)
72+
%padding : int[] = prim::ListConstruct(%4)
73+
%dilation : int[] = prim::ListConstruct(%5)
74+
%output_padding : int[] = prim::ListConstruct(%6)
75+
%12 : Tensor = aten::conv_transpose1d(%0, %1, %2, %stride, %padding, %output_padding, %3, %dilation)
76+
return (%12))IR";
77+
78+
const auto target_graph = R"IR(
79+
graph(%0 : Tensor,
80+
%1 : Float(8, 3, 3, strides=[9, 3, 1]),
81+
%2 : Float(3)):
82+
%3 : int = prim::Constant[value=1]()
83+
%4 : int = prim::Constant[value=0]()
84+
%5 : int = prim::Constant[value=1]()
85+
%6 : int = prim::Constant[value=0]()
86+
%7 : bool = prim::Constant[value=0]()
87+
%8 : bool = prim::Constant[value=1]()
88+
%stride : int[] = prim::ListConstruct(%3)
89+
%padding : int[] = prim::ListConstruct(%4)
90+
%dilation : int[] = prim::ListConstruct(%5)
91+
%output_padding : int[] = prim::ListConstruct(%6)
92+
%12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %8, %output_padding, %5, %7, %7, %7, %7)
93+
return (%12))IR";
94+
95+
trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
96+
auto sg = std::make_shared<torch::jit::Graph>();
97+
torch::jit::parseIR(source_graph, &*sg);
98+
trtorch::core::lowering::passes::ConvTransposed1DToConvolution(sg);
99+
100+
auto tg = std::make_shared<torch::jit::Graph>();
101+
torch::jit::parseIR(target_graph, &*tg);
102+
103+
auto in = at::randint(1, 2, {1, 8, 3}, {at::kCUDA});
104+
auto w = at::randint(1, 2, {8, 3, 3}, {at::kCUDA});
105+
auto b = at::randint(1, 10, {3}, {at::kCUDA});
106+
107+
auto trt_in = at::clone(in);
108+
auto trt_w = at::clone(w);
109+
auto trt_b = at::clone(b);
110+
auto params = trtorch::core::conversion::get_named_params(sg->inputs(), {trt_w, trt_b});
111+
auto trt_results_sg = trtorch::tests::util::RunGraphEngine(sg, params, {trt_in});
112+
113+
params = trtorch::core::conversion::get_named_params(tg->inputs(), {trt_w, trt_b});
114+
auto trt_results_tg = trtorch::tests::util::RunGraphEngine(tg, params, {trt_in});
115+
116+
ASSERT_TRUE(trtorch::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
117+
}

0 commit comments

Comments
 (0)