Skip to content

Commit 169c5bc

Browse files
committed
chore: Fix merge conflicts
Signed-off-by: Dheeraj Peri <[email protected]>
2 parents 587b3a1 + 54d5e4c commit 169c5bc

File tree

11 files changed

+860
-282
lines changed

11 files changed

+860
-282
lines changed

core/conversion/conversion.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,15 @@ c10::optional<torch::jit::IValue> EvaluateNode(ConversionCtx* ctx, const torch::
4545
if (result) {
4646
// WARN: If the converter returns None then should pass through
4747
// but if repeated dep this section will get called each time
48-
ctx->evaluated_value_map[eval_in] = std::move(result.value());
49-
eval_args[eval_in] = &(ctx->evaluated_value_map[eval_in]);
48+
auto val = result.value();
49+
if (val.isCustomClass()) {
50+
auto cont = val.toCustomClass<TensorContainer>();
51+
ctx->AssociateValueAndTensor(eval_in, cont->tensor());
52+
eval_args[eval_in] = ctx->value_tensor_map[eval_in];
53+
} else {
54+
ctx->AssociateValueAndIValue(eval_in, val);
55+
eval_args[eval_in] = &(ctx->evaluated_value_map[eval_in]);
56+
}
5057
}
5158
} else {
5259
TRTORCH_THROW_ERROR(
@@ -374,6 +381,11 @@ void ConvertBlockToNetDef(
374381
} else {
375382
TRTORCH_THROW_ERROR("Unsupported return type for evaluated node");
376383
}
384+
} else if (eval.value().isCustomClass()) {
385+
auto container = eval.value().toCustomClass<TensorContainer>();
386+
auto tensor = container->tensor();
387+
LOG_DEBUG(ctx->logger, "Found the value to be an ITensor of shape: " << tensor->getDimensions());
388+
ctx->AssociateValueAndTensor(n->output(0), tensor);
377389
} else if (!eval.value().isTensor()) {
378390
LOG_DEBUG(ctx->logger, "Found the value to be: " << eval.value());
379391
ctx->AssociateValueAndIValue(n->output(0), eval.value());

core/conversion/converters/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ cc_library(
3535
"impl/batch_norm.cpp",
3636
"impl/concat.cpp",
3737
"impl/constant.cpp",
38+
"impl/constant_pad.cpp",
3839
"impl/conv_deconv.cpp",
3940
"impl/cumsum.cpp",
4041
"impl/element_wise.cpp",
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
#include <ATen/ATen.h>
#include <vector>
#include "NvInfer.h"
#include "core/conversion/converters/converters.h"
#include "core/util/prelude.h"
#include "torch/torch.h"

namespace trtorch {
namespace core {
namespace conversion {
namespace converters {
namespace impl {
namespace {

// Converter for aten::constant_pad_nd.
//
// Pads the last `pad.size() / 2` dimensions of the input with a constant
// `value`. For each padded axis a constant-filled tensor is built with an
// IFillLayer and concatenated onto the corresponding side of the input.
// Padding pairs are interpreted innermost-dimension first, matching
// PyTorch's F.pad convention:
//   pad = (left, right, top, bottom, front, back, ...)
// TODO: negative padding (cropping) is not handled yet.
auto constant_pad_registrations TRTORCH_UNUSED = RegisterNodeConversionPatterns().pattern(
    {"aten::constant_pad_nd(Tensor self, int[] pad, Scalar value=0) -> (Tensor)",
     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
       auto in = args[0].ITensor();
       auto inDims = in->getDimensions();
       int64_t inRank = inDims.nbDims;
       auto padding = args[1].unwrapToIntList().vec();
       int64_t padSize = padding.size();
       auto value = args[2].unwrapToScalar().to<float>();

       TRTORCH_CHECK(padSize % 2 == 0, "Length of pad must be even but instead it equals " << padSize);

       int64_t l_pad = padSize / 2;
       TRTORCH_CHECK(
           inRank >= (int64_t)l_pad,
           "Length of pad should be no more than twice the number of "
           "dimensions of the input. Pad length is "
               << padSize << " while the input has " << inRank << " dimensions.");

       // Builds a tensor of static shape `dims` where every element is `value`.
       // A LINSPACE fill with an all-zero delta vector produces a constant tensor.
       auto static_fill = [&](nvinfer1::Dims dims) -> nvinfer1::ITensor* {
         auto fill_layer = ctx->net->addFill(dims, nvinfer1::FillOperation::kLINSPACE);
         TRTORCH_CHECK(fill_layer, "Unable to create fill layer from node: " << *n);
         at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
         fill_layer->setInput(1, *tensor_to_const(ctx, value_tensor));
         at::Tensor delta_tensor = torch::zeros(inRank);
         fill_layer->setInput(2, *tensor_to_const(ctx, delta_tensor));
         return fill_layer->getOutput(0);
       };

       // Builds a constant tensor of `value` whose runtime shape matches `like`
       // (dynamic-shape case): the fill's output shape is driven by shape(like).
       auto dynamic_fill = [&](nvinfer1::ITensor* like) -> nvinfer1::ITensor* {
         auto fill_layer = ctx->net->addFill(nvinfer1::Dims{1, {1}}, nvinfer1::FillOperation::kLINSPACE);
         TRTORCH_CHECK(fill_layer, "Unable to create fill layer from node: " << *n);
         fill_layer->setInput(0, *ctx->net->addShape(*like)->getOutput(0));
         at::Tensor value_tensor = torch::tensor(value, torch::kFloat32);
         fill_layer->setInput(1, *tensor_to_const(ctx, value_tensor));
         at::Tensor delta_tensor = torch::zeros(inRank);
         fill_layer->setInput(2, *tensor_to_const(ctx, delta_tensor));
         return fill_layer->getOutput(0);
       };

       std::vector<nvinfer1::ITensor*> tensors_vec;
       // input: (N, C, D_in, H_in, W_in).
       // padding: (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back)
       // When axis is inRank - 1, making W_out = W_in + padding_left + padding_right.
       // When axis is inRank - 2, making H_out = H_in + padding_top + padding_bottom.
       // When axis is inRank - 3, making D_out = D_in + padding_front + padding_back.
       for (int64_t i = 0; i < l_pad; i++) {
         int64_t axis = inRank - (i + 1); // axis = {inRank - 1, inRank - 2, inRank - 3}
         int64_t padding_index = i * 2;

         if (padding[padding_index] > 0) { // left/top/front padding value
           tensors_vec.clear();
           if (ctx->input_is_dynamic) {
             // Slice one element along `axis` so the pad tensor inherits the
             // runtime shape of the input on every other axis.
             at::Tensor left_indices = torch::tensor({0}, torch::kInt32);
             auto indicesTensor = tensor_to_const(ctx, left_indices);
             auto left_gather_out = ctx->net->addGather(*in, *indicesTensor, axis)->getOutput(0);
             auto padTensor = dynamic_fill(left_gather_out);
             // One single-width slab per unit of padding; all get concatenated.
             for (int64_t p = 0; p < padding[padding_index]; p++) {
               tensors_vec.push_back(padTensor);
             }
           } else {
             // Static shapes: build the whole pad slab in one fill.
             inDims.d[axis] = padding[padding_index];
             tensors_vec.push_back(static_fill(inDims));
           }

           tensors_vec.push_back(in);
           auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
           TRTORCH_CHECK(concat_layer, "Unable to create concat layer from node: " << *n);
           concat_layer->setAxis(axis);
           in = concat_layer->getOutput(0);
           inDims = in->getDimensions();
         }

         if (padding[padding_index + 1] > 0) { // right/bottom/back padding value
           tensors_vec.clear();
           tensors_vec.push_back(in);

           if (ctx->input_is_dynamic) {
             // Gather the last slice along `axis`. The index is computed at
             // runtime (shape[axis] - 1) when the extent is unknown (-1).
             nvinfer1::ITensor* indicesTensor = NULL;
             if (inDims.d[axis] == -1) {
               auto shapeTensor = ctx->net->addShape(*in)->getOutput(0);
               at::Tensor dimValue = torch::tensor({axis}, torch::kInt32);
               auto dimTensor = tensor_to_const(ctx, dimValue);
               indicesTensor = ctx->net->addGather(*shapeTensor, *dimTensor, 0)->getOutput(0);
               auto oneTensor = tensor_to_const(ctx, torch::tensor({1}, torch::kInt32));
               indicesTensor = ctx->net->addElementWise(*indicesTensor, *oneTensor, nvinfer1::ElementWiseOperation::kSUB)
                                   ->getOutput(0);
             } else {
               auto indices = torch::tensor({inDims.d[axis] - 1}, torch::kInt32);
               indicesTensor = tensor_to_const(ctx, indices);
             }
             auto right_gather_out = ctx->net->addGather(*in, *indicesTensor, axis)->getOutput(0);
             auto padTensor = dynamic_fill(right_gather_out);
             for (int64_t p = 0; p < padding[padding_index + 1]; p++) {
               tensors_vec.push_back(padTensor);
             }
           } else {
             // Static shapes: one fill covers the whole pad slab; no gather needed.
             inDims.d[axis] = padding[padding_index + 1];
             tensors_vec.push_back(static_fill(inDims));
           }

           auto concat_layer = ctx->net->addConcatenation(tensors_vec.data(), tensors_vec.size());
           TRTORCH_CHECK(concat_layer, "Unable to create concat layer from node: " << *n);
           concat_layer->setAxis(axis);
           in = concat_layer->getOutput(0);
           inDims = in->getDimensions();
         }
       }

       auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in);
       LOG_DEBUG("Output tensor shape: " << out->getDimensions());
       return true;
     }});

} // namespace
} // namespace impl
} // namespace converters
} // namespace conversion
} // namespace core
} // namespace trtorch

core/conversion/evaluators/aten.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,20 @@ auto aten_registrations TRTORCH_UNUSED =
128128
auto out_tensor = torch::zeros(args.at(n->input(0)).unwrapToIntList().vec(), options);
129129
return out_tensor;
130130
}})
131+
.evaluator({c10::Symbol::fromQualString("aten::ones"),
132+
// aten::ones(int[] size, *, int? dtype=None, int? layout=None,
133+
// Device? device=None, bool? pin_memory=None) -> (Tensor)
134+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
135+
auto options = torch::TensorOptions().layout(torch::kStrided).device(torch::kCUDA);
136+
137+
// Input 1 here is the dtype
138+
if (!args.at(n->input(1)).isNone() && !args.at(n->input(1)).IValue()->isNone()) {
139+
options = options.dtype(c10::ScalarType(args.at(n->input(1)).unwrapToInt()));
140+
}
141+
142+
auto out_tensor = torch::ones(args.at(n->input(0)).unwrapToIntList().vec(), options);
143+
return out_tensor;
144+
}})
131145
.evaluator({c10::Symbol::fromQualString("aten::slice"),
132146
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
133147
c10::List<c10::IValue> list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
@@ -216,9 +230,17 @@ auto aten_registrations TRTORCH_UNUSED =
216230
.evaluator({c10::Symbol::fromQualString("aten::append"),
217231
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
218232
auto list = args.at(n->input(0)).IValue()->to<c10::List<c10::IValue>>();
219-
auto el = args.at(n->input(1)).IValue();
220233

221-
list.push_back(std::move(*el));
234+
if (args.at(n->input(1)).isITensor()) {
235+
auto tensor_holder = TensorContainer();
236+
tensor_holder.hold_tensor(args.at(n->input(1)).ITensor());
237+
auto el = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
238+
list.push_back(std::move(el));
239+
} else {
240+
auto el = args.at(n->input(1)).IValue();
241+
list.push_back(std::move(*el));
242+
}
243+
222244
return list;
223245
},
224246
EvalOptions().validSchemas({

core/plugins/impl/interpolate_plugin.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ std::vector<int64_t> InterpolatePlugin::getOutputSize() {
105105
return size_;
106106
}
107107

108+
108109
int InterpolatePlugin::getNbOutputs() const noexcept {
109110
if (mode_ == "adaptive_max_pool2d") {
110111
return 2;
@@ -169,6 +170,7 @@ nvinfer1::DataType InterpolatePlugin::getOutputDataType(int index, const nvinfer
169170
return nvinfer1::DataType::kFLOAT;
170171
}
171172

173+
172174
int InterpolatePlugin::initialize() noexcept {
173175
return 0;
174176
}
@@ -206,6 +208,9 @@ bool InterpolatePlugin::supportsFormatCombination(
206208
const nvinfer1::PluginTensorDesc* inOut,
207209
int nbInputs,
208210
int nbOutputs) noexcept {
211+
212+
TRTORCH_ASSERT(nbInputs == 1, "Expected a single tensor as input to interpolate plugin");
213+
209214
if (mode_ == "adaptive_max_pool2d") {
210215
TRTORCH_ASSERT(nbOutputs == 2, "Expected 2 tensors as output to interpolate plugin");
211216
TRTORCH_ASSERT(0 <= pos && pos <= 2, "There should be exactly 3 connections to the plugin - 1 input, 2 output");

0 commit comments

Comments
 (0)