diff --git a/core/conversion/converters/impl/squeeze.cpp b/core/conversion/converters/impl/squeeze.cpp
index 8f4b437a88..f934934af4 100644
--- a/core/conversion/converters/impl/squeeze.cpp
+++ b/core/conversion/converters/impl/squeeze.cpp
@@ -14,35 +14,57 @@ namespace converters {
 namespace impl {
 namespace {
 
-auto squeeze_registrations TORCHTRT_UNUSED = RegisterNodeConversionPatterns().pattern(
-    {"aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))",
-     [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-       auto self = args[0].ITensorOrFreeze(ctx);
-       auto dim = args[1].unwrapToInt();
+auto squeeze_registrations TORCHTRT_UNUSED =
+    RegisterNodeConversionPatterns()
+        .pattern(
+            {"aten::squeeze.dim(Tensor(a) self, int dim) -> (Tensor(a))",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               auto dim = args[1].unwrapToInt();
 
-       auto selfDim = util::toVec(self->getDimensions());
-       if (dim < 0) {
-         dim = selfDim.size() + dim;
-       }
+               auto selfDim = util::toVec(self->getDimensions());
+               if (dim < 0) {
+                 dim = selfDim.size() + dim;
+               }
 
-       if (selfDim[dim] != 1) {
-         auto out = ctx->AssociateValueAndTensor(n->outputs()[0], self);
+               if (selfDim[dim] != 1) {
+                 auto out = ctx->AssociateValueAndTensor(n->outputs()[0], self);
 
-         LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                 LOG_DEBUG("Output tensor shape: " << out->getDimensions());
 
-         return true;
-       }
+                 return true;
+               }
 
-       auto shuffle_layer = ctx->net->addShuffle(*self);
-       TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
-       shuffle_layer->setReshapeDimensions(util::squeezeDims(self->getDimensions(), dim));
+               auto shuffle_layer = ctx->net->addShuffle(*self);
+               TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
+               shuffle_layer->setReshapeDimensions(util::squeezeDims(self->getDimensions(), dim));
 
-       auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_layer->getOutput(0));
+               auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_layer->getOutput(0));
 
-       LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+               LOG_DEBUG("Output tensor shape: " << out->getDimensions());
 
-       return true;
-     }});
+               return true;
+             }})
+        .pattern(
+            {"aten::squeeze(Tensor(a) self) -> (Tensor(a))",
+             [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+               auto self = args[0].ITensorOrFreeze(ctx);
+               auto self_dims = self->getDimensions();
+               auto out = self;
+               auto squeeze_dims = util::squeezeAllDims(self_dims);
+               if (squeeze_dims != self_dims) {
+                 auto shuffle_layer = ctx->net->addShuffle(*self);
+                 TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n);
+                 shuffle_layer->setReshapeDimensions(squeeze_dims);
+                 out = shuffle_layer->getOutput(0);
+               }
+
+               auto trt_out = ctx->AssociateValueAndTensor(n->outputs()[0], out);
+
+               LOG_DEBUG("Output tensor shape: " << trt_out->getDimensions());
+
+               return true;
+             }});
 
 } // namespace
 } // namespace impl
diff --git a/core/util/trt_util.cpp b/core/util/trt_util.cpp
index 208d93296c..d320992a70 100644
--- a/core/util/trt_util.cpp
+++ b/core/util/trt_util.cpp
@@ -196,6 +196,19 @@ nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros) {
   return dims;
 }
 
+nvinfer1::Dims squeezeAllDims(const nvinfer1::Dims& d, bool use_zeros_for_unknown_dims) {
+  nvinfer1::Dims dims;
+  int j = 0;
+  for (int i = 0; i < d.nbDims; i++) {
+    if (d.d[i] != 1) {
+      dims.d[j++] = (use_zeros_for_unknown_dims && d.d[i] == -1) ? 0 : d.d[i];
+    }
+  }
+  dims.nbDims = j;
+
+  return dims;
+}
+
 std::vector<int64_t> toVec(nvinfer1::Dims d) {
   std::vector<int64_t> dims;
   for (int i = 0; i < d.nbDims; i++) {
diff --git a/core/util/trt_util.h b/core/util/trt_util.h
index 2d72178098..355b0d13cc 100644
--- a/core/util/trt_util.h
+++ b/core/util/trt_util.h
@@ -137,6 +137,7 @@ nvinfer1::Dims toDimsTailPad(c10::List<int64_t> l, uint64_t pad_to);
 nvinfer1::Dims unpadDims(const nvinfer1::Dims& d);
 nvinfer1::Dims unsqueezeDims(const nvinfer1::Dims& d, int pos, int val = 1, bool use_zeros = true);
 nvinfer1::Dims squeezeDims(const nvinfer1::Dims& d, int pos, bool use_zeros = true);
+nvinfer1::Dims squeezeAllDims(const nvinfer1::Dims& d, bool use_zeros_for_unknown_dims = true);
 nvinfer1::Dims toDims(c10::IntArrayRef l);
 nvinfer1::Dims toDims(c10::List<int64_t> l);
 nvinfer1::DimsHW toDimsHW(c10::List<int64_t> l);
diff --git a/tests/core/conversion/converters/test_squeeze.cpp b/tests/core/conversion/converters/test_squeeze.cpp
index 5c40848744..09334710b6 100644
--- a/tests/core/conversion/converters/test_squeeze.cpp
+++ b/tests/core/conversion/converters/test_squeeze.cpp
@@ -56,3 +56,29 @@ TEST(Converters, ATenSqueezeDontNeedSqueezeConvertsCorrectly) {
   ASSERT_TRUE(
       torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
 }
+
+TEST(Converters, ATenSqueezeNoDimConvertsCorrectly) {
+  const auto graph = R"IR(
+    graph(%0 : Tensor):
+      %1 : Tensor = aten::squeeze(%0)
+      return (%1))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto validate_squeeze_with_input = [&g](const at::Tensor& in) {
+    auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+    auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in});
+
+    params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
+    auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in});
+    ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
+  };
+
+  validate_squeeze_with_input(at::randint(1, 10, {2, 1, 3, 3}, {at::kCUDA}));
+  validate_squeeze_with_input(at::randint(1, 10, {1, 1, 1, 3}, {at::kCUDA}));
+  validate_squeeze_with_input(at::randint(1, 10, {1, 10, 1, 3}, {at::kCUDA}));
+  validate_squeeze_with_input(at::randint(1, 10, {2, 10, 3, 3}, {at::kCUDA}));
+  validate_squeeze_with_input(at::randint(1, 10, {1, 1}, {at::kCUDA}));
+  validate_squeeze_with_input(at::randint(1, 10, {1}, {at::kCUDA}));
+}