#include <string>
#include "core/compiler.h"
#include "core/lowering/passes/passes.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"
#include "torch/csrc/jit/ir/subgraph_matcher.h"

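// Checks that the Conv1DToConvolution lowering pass rewrites aten::conv1d into an
// equivalent aten::_convolution node, and that the source and target graphs produce
// the same result when run through RunGraphEngine.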
TEST(LoweringPasses, Conv1dCorrectly) {
  const auto source_graph = R"IR(
    graph(%0 : Tensor,
          %1 : Float(4, 3, 3, strides=[9, 3, 1]),
          %2 : Float(3)):
      %4 : int = prim::Constant[value=0]()
      %5 : int = prim::Constant[value=1]()
      %6 : int = prim::Constant[value=1]()
      %stride : int[] = prim::ListConstruct(%6)
      %padding : int[] = prim::ListConstruct(%4)
      %dilation : int[] = prim::ListConstruct(%5)
      %12 : Tensor = aten::conv1d(%0, %1, %2, %stride, %padding, %dilation, %6)
      return (%12))IR";

  const auto target_graph = R"IR(
    graph(%0 : Tensor,
          %1 : Float(4, 3, 3, strides=[9, 3, 1]),
          %2 : Float(3)):
      %3 : bool = prim::Constant[value=0]()
      %4 : int = prim::Constant[value=0]()
      %5 : int = prim::Constant[value=1]()
      %6 : int = prim::Constant[value=1]()
      %stride : int[] = prim::ListConstruct(%6)
      %padding : int[] = prim::ListConstruct(%4)
      %dilation : int[] = prim::ListConstruct(%5)
      %output_padding : int[] = prim::Constant[value=[0]]()
      %12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %3, %output_padding, %6, %3, %3, %3, %3)
      return (%12))IR";

  trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, &*sg);
  trtorch::core::lowering::passes::Conv1DToConvolution(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, &*tg);

  auto in = at::randint(1, 2, {1, 3, 3}, {at::kCUDA});
  auto w = at::randint(1, 2, {4, 3, 3}, {at::kCUDA});
  auto b = at::randint(1, 10, {4}, {at::kCUDA});

  auto trt_in = at::clone(in);
  auto trt_w = at::clone(w);
  auto trt_b = at::clone(b);
  auto params = trtorch::core::conversion::get_named_params(sg->inputs(), {trt_w, trt_b});
  auto trt_results_sg = trtorch::tests::util::RunGraphEngine(sg, params, {trt_in});

  params = trtorch::core::conversion::get_named_params(tg->inputs(), {trt_w, trt_b});
  auto trt_results_tg = trtorch::tests::util::RunGraphEngine(tg, params, {trt_in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
}

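// Checks that the ConvTransposed1DToConvolution lowering pass rewrites aten::conv_transpose1d
// into an equivalent transposed aten::_convolution node, and that the source and target graphs
// produce the same result when run through RunGraphEngine.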
TEST(LoweringPasses, ConvTransposed1dCorrectly) {
  const auto source_graph = R"IR(
    graph(%0 : Tensor,
          %1 : Float(8, 3, 3, strides=[9, 3, 1]),
          %2 : Float(3)):
      %3 : int = prim::Constant[value=1]()
      %4 : int = prim::Constant[value=0]()
      %5 : int = prim::Constant[value=1]()
      %6 : int = prim::Constant[value=0]()
      %stride : int[] = prim::ListConstruct(%3)
      %padding : int[] = prim::ListConstruct(%4)
      %dilation : int[] = prim::ListConstruct(%5)
      %output_padding : int[] = prim::ListConstruct(%6)
      %12 : Tensor = aten::conv_transpose1d(%0, %1, %2, %stride, %padding, %output_padding, %3, %dilation)
      return (%12))IR";

  const auto target_graph = R"IR(
    graph(%0 : Tensor,
          %1 : Float(8, 3, 3, strides=[9, 3, 1]),
          %2 : Float(3)):
      %3 : int = prim::Constant[value=1]()
      %4 : int = prim::Constant[value=0]()
      %5 : int = prim::Constant[value=1]()
      %6 : int = prim::Constant[value=0]()
      %7 : bool = prim::Constant[value=0]()
      %8 : bool = prim::Constant[value=1]()
      %stride : int[] = prim::ListConstruct(%3)
      %padding : int[] = prim::ListConstruct(%4)
      %dilation : int[] = prim::ListConstruct(%5)
      %output_padding : int[] = prim::ListConstruct(%6)
      %12 : Tensor = aten::_convolution(%0, %1, %2, %stride, %padding, %dilation, %8, %output_padding, %5, %7, %7, %7, %7)
      return (%12))IR";

  trtorch::core::util::logging::get_logger().set_reportable_log_level(trtorch::core::util::logging::LogLevel::kGRAPH);
  auto sg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source_graph, &*sg);
  trtorch::core::lowering::passes::ConvTransposed1DToConvolution(sg);

  auto tg = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(target_graph, &*tg);

  auto in = at::randint(1, 2, {1, 8, 3}, {at::kCUDA});
  auto w = at::randint(1, 2, {8, 3, 3}, {at::kCUDA});
  auto b = at::randint(1, 10, {3}, {at::kCUDA});

  auto trt_in = at::clone(in);
  auto trt_w = at::clone(w);
  auto trt_b = at::clone(b);
  auto params = trtorch::core::conversion::get_named_params(sg->inputs(), {trt_w, trt_b});
  auto trt_results_sg = trtorch::tests::util::RunGraphEngine(sg, params, {trt_in});

  params = trtorch::core::conversion::get_named_params(tg->inputs(), {trt_w, trt_b});
  auto trt_results_tg = trtorch::tests::util::RunGraphEngine(tg, params, {trt_in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(trt_results_sg[0], trt_results_tg[0], 2e-6));
}