diff --git a/core/conversion/converters/converter_util.cpp b/core/conversion/converters/converter_util.cpp
index 9312706b47..1346f7eeb3 100644
--- a/core/conversion/converters/converter_util.cpp
+++ b/core/conversion/converters/converter_util.cpp
@@ -199,6 +199,131 @@ nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t, const std::string& name) {
   return out;
 }
 
+// clamp x to [lower_bound, upper_bound]
+nvinfer1::ITensor* clamp(
+    ConversionCtx* ctx,
+    nvinfer1::ITensor* x,
+    nvinfer1::ITensor* lower_bound,
+    nvinfer1::ITensor* upper_bound,
+    std::string const& name) {
+
+  auto max_layer = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMAX, x, lower_bound, "max layer for " + name);
+  TORCHTRT_CHECK(max_layer, "Unable to create max layer for clamp");
+  LOG_DEBUG(ctx->logger, "Create " << max_layer->getName() << " for clamp");
+  auto max_itensor = max_layer->getOutput(0);
+
+  auto min_layer = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMIN, max_itensor, upper_bound, "min layer for " + name);
+  TORCHTRT_CHECK(min_layer, "Unable to create min layer for clamp");
+  LOG_DEBUG(ctx->logger, "Create " << min_layer->getName() << " for clamp");
+  auto min_itensor = min_layer->getOutput(0);
+  return min_itensor;
+}
+
+// clamp x to [0, input_dim]
+nvinfer1::ITensor* clamp_to_input_dim(
+    ConversionCtx* ctx,
+    nvinfer1::ITensor* x,
+    nvinfer1::ITensor* input_dim,
+    int nbdims,
+    std::string const& name) {
+
+  auto zero = torch::zeros({nbdims}).to(torch::kI32);
+  auto zero_itensor = tensor_to_const(ctx, zero);
+  auto one = torch::ones({nbdims}).to(torch::kI32);
+  auto one_itensor = tensor_to_const(ctx, one);
+
+  auto upper_bound_layer = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, input_dim, one_itensor, "sub layer for " + name);
+  TORCHTRT_CHECK(upper_bound_layer, "Unable to create sub layer for clamp to inputDim");
+  LOG_DEBUG(ctx->logger, "Create " << upper_bound_layer->getName() << " for clamp to inputDim");
+  auto upper_bound = upper_bound_layer->getOutput(0);
+
+  auto max_layer = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMAX, x, zero_itensor, "max layer for " + name);
+  TORCHTRT_CHECK(max_layer, "Unable to create max_layer for clamp to inputDim");
+  LOG_DEBUG(ctx->logger, "Create " << max_layer->getName() << " for clamp to inputDim");
+  auto max_itensor = max_layer->getOutput(0);
+
+  auto min_layer = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kMIN, max_itensor, upper_bound, "min layer for " + name);
+  TORCHTRT_CHECK(min_layer, "Unable to create min_layer for clamp to inputDim");
+  LOG_DEBUG(ctx->logger, "Create " << min_layer->getName() << " for clamp to inputDim");
+  auto min_itensor = min_layer->getOutput(0);
+  return min_itensor;
+}
+
+// return indices < 0 ? inputDims + indices : indices
+nvinfer1::ITensor* normalize_indices(
+    ConversionCtx* ctx,
+    nvinfer1::ITensor* input_dim,
+    nvinfer1::ITensor* indices,
+    int nbdims,
+    std::string const& name) {
+
+  auto zero = torch::zeros({nbdims}).to(torch::kI32);
+  auto neg = -torch::ones({nbdims}).to(torch::kI32);
+  auto zero_itensor = tensor_to_const(ctx, zero);
+  auto neg_itensor = tensor_to_const(ctx, neg);
+  // signs is -1 where the index is negative, 0 otherwise
+  auto signs = clamp(ctx, indices, neg_itensor, zero_itensor, "clamp layer for " + name);
+
+  // -inputDim where the index is negative, 0 elsewhere
+  auto mul = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kPROD, signs, input_dim, "prod layer for " + name);
+  TORCHTRT_CHECK(mul, "Unable to create mul layer in normalize_indices");
+  LOG_DEBUG(ctx->logger, "Create " << mul->getName() << " for normalize_indices");
+  auto mul_itensor = mul->getOutput(0);
+
+  // indices - (-inputDim) adds inputDim to every negative index
+  auto sub = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, indices, mul_itensor, "sub layer for " + name);
+  TORCHTRT_CHECK(sub, "Unable to create sub layer in normalize_indices");
+  LOG_DEBUG(ctx->logger, "Create " << sub->getName() << " for normalize_indices");
+  auto sub_itensor = sub->getOutput(0);
+  return sub_itensor;
+}
+
+std::vector<nvinfer1::ITensor*> normalize_start_and_end(
+    ConversionCtx* ctx,
+    nvinfer1::ITensor* in_shape,
+    nvinfer1::ITensor* in_start,
+    nvinfer1::ITensor* in_end,
+    int nbdims,
+    std::string const& name) {
+  auto start = normalize_indices(ctx, in_shape, in_start, nbdims, "normalize start of " + name);
+  auto out_start = clamp_to_input_dim(ctx, start, in_shape, nbdims, "clamp start to inputDim for " + name);
+  auto end = normalize_indices(ctx, in_shape, in_end, nbdims, "normalize end of " + name);
+  auto out_end = clamp_to_input_dim(ctx, end, in_shape, nbdims, "clamp end to inputDim for " + name);
+  std::vector<nvinfer1::ITensor*> outputs;
+  outputs.push_back(out_start);
+  outputs.push_back(out_end);
+  return outputs;
+}
+
+// size = (end - start) / stride + 1, where range is [start, end], end is included
+nvinfer1::ITensor* get_slice_size(
+    ConversionCtx* ctx,
+    nvinfer1::ITensor* start,
+    nvinfer1::ITensor* end,
+    nvinfer1::ITensor* stride,
+    int nbdims,
+    std::string const& name) {
+  at::Tensor one_tensor = torch::ones({nbdims}).to(torch::kI32);
+  auto one_itensor = tensor_to_const(ctx, one_tensor);
+
+  auto sub_layer = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUB, end, start, "get_slice_size sub layer for " + name);
+  TORCHTRT_CHECK(sub_layer, "Unable to create sub layer in calculate_output_size");
+  LOG_DEBUG(ctx->logger, "Create " << sub_layer->getName() << " for calculate_output_size");
+  auto sub_itensor = sub_layer->getOutput(0);
+
+  auto div_layer = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, sub_itensor, stride, "get_slice_size div layer for " + name);
+  TORCHTRT_CHECK(div_layer, "Unable to create div layer in calculate_output_size");
+  LOG_DEBUG(ctx->logger, "Create " << div_layer->getName() << " for calculate_output_size");
+  auto div_itensor = div_layer->getOutput(0);
+
+  auto add_layer = add_elementwise(ctx, nvinfer1::ElementWiseOperation::kSUM, div_itensor, one_itensor, "get_slice_size sum layer for " + name);
+  TORCHTRT_CHECK(add_layer, "Unable to create add layer in calculate_output_size");
+  LOG_DEBUG(ctx->logger, "Create " << add_layer->getName() << " for calculate_output_size");
+  auto size_itensor = add_layer->getOutput(0);
+
+  return size_itensor;
+}
+
 } // namespace converters
 } // namespace conversion
 } // namespace core
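The helpers above express simple per-dimension integer arithmetic as TensorRT elementwise layers (kMAX, kMIN, kPROD, kSUB, kDIV, kSUM) so it can be evaluated on shape tensors at runtime. A minimal standalone sketch of that arithmetic (plain C++17, no TensorRT; names are illustrative and not part of the patch):

// Illustrative only: the scalar math that clamp, clamp_to_input_dim,
// normalize_indices and get_slice_size build out of elementwise layers.
#include <algorithm>
#include <cassert>

// normalize_indices: idx < 0 ? idx + dim : idx
int normalize_index(int idx, int dim) {
  int sign = std::clamp(idx, -1, 0); // -1 for a negative idx, 0 otherwise
  return idx - sign * dim;           // adds dim only when idx is negative
}

// clamp_to_input_dim: clamp idx to [0, dim - 1]
int clamp_to_dim(int idx, int dim) {
  return std::min(std::max(idx, 0), dim - 1);
}

// get_slice_size: (end - start) / stride + 1, with end inclusive
int slice_size(int start, int end, int stride) {
  return (end - start) / stride + 1;
}

int main() {
  const int dim = 16;
  assert(normalize_index(-2, dim) == 14);
  assert(clamp_to_dim(20, dim) == 15);
  assert(slice_size(1, 14, 2) == 7); // elements 1, 3, ..., 13
  return 0;
}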
diff --git a/core/conversion/converters/converter_util.h b/core/conversion/converters/converter_util.h
index 3f134666d6..cdf2ee5a8d 100644
--- a/core/conversion/converters/converter_util.h
+++ b/core/conversion/converters/converter_util.h
@@ -2,6 +2,7 @@
 
 #include <map>
 #include <string>
+#include <vector>
 
 #include "core/conversion/conversionctx/ConversionCtx.h"
 #include "core/conversion/converters/Weights.h"
@@ -50,6 +51,35 @@ nvinfer1::ITensor* castITensor(ConversionCtx* ctx, nvinfer1::ITensor* tensor, nvinfer1::DataType dtype);
 // Freeze an at::Tensor in a IConstant layer
 nvinfer1::ITensor* tensor_to_const(ConversionCtx* ctx, at::Tensor t, const std::string& name = std::string());
 
+nvinfer1::ITensor* clamp(
+    ConversionCtx* ctx,
+    nvinfer1::ITensor* x,
+    nvinfer1::ITensor* lower_bound,
+    nvinfer1::ITensor* upper_bound,
+    std::string const& name);
+
+nvinfer1::ITensor* normalize_indices(
+    ConversionCtx* ctx,
+    nvinfer1::ITensor* input_dim,
+    nvinfer1::ITensor* indices,
+    int nbdims,
+    std::string const& name);
+
+std::vector<nvinfer1::ITensor*> normalize_start_and_end(
+    ConversionCtx* ctx,
+    nvinfer1::ITensor* in_shape,
+    nvinfer1::ITensor* in_start,
+    nvinfer1::ITensor* in_end,
+    int nbdims,
+    std::string const& name);
+
+nvinfer1::ITensor* get_slice_size(
+    ConversionCtx* ctx,
+    nvinfer1::ITensor* start,
+    nvinfer1::ITensor* end,
+    nvinfer1::ITensor* stride,
+    int nbdims,
+    std::string const& name);
+
 } // namespace converters
 } // namespace conversion
 } // namespace core
diff --git a/core/conversion/converters/impl/select.cpp b/core/conversion/converters/impl/select.cpp
index 27b29b4195..3599ab9939 100644
--- a/core/conversion/converters/impl/select.cpp
+++ b/core/conversion/converters/impl/select.cpp
@@ -103,118 +103,121 @@ nvinfer1::ITensor* roll(
 auto select_registrations TORCHTRT_UNUSED =
     RegisterNodeConversionPatterns()
-        .pattern({"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))",
-                  [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
-                    auto in = args[0].ITensorOrFreeze(ctx);
-                    auto maxDim = static_cast<int64_t>(in->getDimensions().nbDims);
-                    auto dim = args[1].unwrapToInt();
-                    // Handle negative axis by refering to nbDims of input Tensor
-                    dim = dim < 0 ? dim + maxDim : dim;
-                    auto ind = (int32_t)args[2].unwrapToInt();
-                    // Along the specified dimension, handle negative index by subtracting along length of dimension.
-                    ind = ind < 0 ?
ind + in->getDimensions().d[dim] : ind; - LOG_DEBUG("Gather input dimensions: " << in->getDimensions()); - LOG_DEBUG("Dimension to select: " << dim); - LOG_DEBUG("Index: " << ind); - - // index to access needs to be an at::Tensor - at::Tensor indices = torch::tensor({ind}).to(torch::kI32); - auto const_out = tensor_to_const(ctx, indices); - - // IGatherLayer takes in input tensor, the indices, and the axis - // of input tensor to take indices from - auto gather_layer = ctx->net->addGather(*in, *const_out, dim); - TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n); - auto out = gather_layer->getOutput(0); - - LOG_DEBUG("Gather tensor shape: " << out->getDimensions()); - - if (out->getDimensions().nbDims != 1) { - // IShuffleLayer removes redundant dimensions - auto shuffle_layer = ctx->net->addShuffle(*out); - TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n); - shuffle_layer->setReshapeDimensions(util::squeezeDims(out->getDimensions(), dim)); - shuffle_layer->setName(util::node_info(n).c_str()); - out = shuffle_layer->getOutput(0); - } - - out = ctx->AssociateValueAndTensor(n->outputs()[0], out); - - LOG_DEBUG("Output tensor shape: " << out->getDimensions()); + .pattern( + {"aten::select.int(Tensor(a) self, int dim, int index) -> (Tensor(a))", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensorOrFreeze(ctx); + auto maxDim = static_cast(in->getDimensions().nbDims); + auto dim = args[1].unwrapToInt(); + // Handle negative axis by refering to nbDims of input Tensor + dim = dim < 0 ? dim + maxDim : dim; + auto ind = (int32_t)args[2].unwrapToInt(); + // Along the specified dimension, handle negative index by subtracting along length of dimension. + ind = ind < 0 ? 
ind + in->getDimensions().d[dim] : ind; + LOG_DEBUG("Gather input dimensions: " << in->getDimensions()); + LOG_DEBUG("Dimension to select: " << dim); + LOG_DEBUG("Index: " << ind); + + // index to access needs to be an at::Tensor + at::Tensor indices = torch::tensor({ind}).to(torch::kI32); + auto const_out = tensor_to_const(ctx, indices); + + // IGatherLayer takes in input tensor, the indices, and the axis + // of input tensor to take indices from + auto gather_layer = ctx->net->addGather(*in, *const_out, dim); + TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n); + auto out = gather_layer->getOutput(0); - return true; - }}) - .pattern({"aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)", - [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto in = args[0].ITensor(); - auto axis = args[1].unwrapToInt(); - auto start = (int32_t)args[2].unwrapToInt(); - auto length = (int32_t)args[3].unwrapToInt(); + LOG_DEBUG("Gather tensor shape: " << out->getDimensions()); - // index to access needs to be an at::Tensor - at::Tensor indices = torch::arange(start, start + length, 1).to(torch::kI32); - auto weights = Weights(ctx, indices); + if (out->getDimensions().nbDims != 1) { + // IShuffleLayer removes redundant dimensions + auto shuffle_layer = ctx->net->addShuffle(*out); + TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n); + shuffle_layer->setReshapeDimensions(util::squeezeDims(out->getDimensions(), dim)); + shuffle_layer->setName(util::node_info(n).c_str()); + out = shuffle_layer->getOutput(0); + } - // IConstantLayer to convert indices from Weights to ITensor - auto const_layer = ctx->net->addConstant(weights.shape, weights.data); - TORCHTRT_CHECK(const_layer, "Unable to create constant layer from node: " << *n); - auto const_out = const_layer->getOutput(0); + out = ctx->AssociateValueAndTensor(n->outputs()[0], out); - // IGatherLayer takes in input tensor, the indices, and the axis - // of input tensor to take indices from - auto gather_layer = ctx->net->addGather(*in, *const_out, axis); - TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n); - auto gather_out = gather_layer->getOutput(0); + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); - // IShuffleLayer removes redundant dimensions - auto shuffle_layer = ctx->net->addShuffle(*gather_out); - TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n); - shuffle_layer->setReshapeDimensions(util::unpadDims(gather_out->getDimensions())); - shuffle_layer->setName(util::node_info(n).c_str()); - auto shuffle_out = shuffle_layer->getOutput(0); + return true; + }}) + .pattern( + {"aten::narrow(Tensor(a) self, int dim, int start, int length) -> Tensor(a)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto axis = args[1].unwrapToInt(); + auto start = (int32_t)args[2].unwrapToInt(); + auto length = (int32_t)args[3].unwrapToInt(); - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_out); + // index to access needs to be an at::Tensor + at::Tensor indices = torch::arange(start, start + length, 1).to(torch::kI32); + auto weights = Weights(ctx, indices); - LOG_DEBUG("Output tensor shape: " << out->getDimensions()); + // IConstantLayer to convert indices from Weights to ITensor + auto const_layer = ctx->net->addConstant(weights.shape, weights.data); + TORCHTRT_CHECK(const_layer, "Unable to create constant layer from 
node: " << *n); + auto const_out = const_layer->getOutput(0); - return true; - }}) - .pattern({"aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)", - [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto in = args[0].ITensor(); - auto axis = args[1].unwrapToInt(); - torch::Tensor start = args[2].IValue()->toTensor().to(torch::kI32); - int32_t startIdx = start.item().to(); - auto length = (int32_t)args[3].unwrapToInt(); - - // index to access needs to be an at::Tensor - at::Tensor indices = torch::arange(startIdx, startIdx + length, 1).to(torch::kI32); - auto weights = Weights(ctx, indices); - - // IConstantLayer to convert indices from Weights to ITensor - auto const_layer = ctx->net->addConstant(weights.shape, weights.data); - TORCHTRT_CHECK(const_layer, "Unable to create constant layer from node: " << *n); - auto const_out = const_layer->getOutput(0); - - // IGatherLayer takes in input tensor, the indices, and the axis - // of input tensor to take indices from - auto gather_layer = ctx->net->addGather(*in, *const_out, axis); - TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n); - auto gather_out = gather_layer->getOutput(0); - - // IShuffleLayer removes redundant dimensions - auto shuffle_layer = ctx->net->addShuffle(*gather_out); - TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n); - shuffle_layer->setReshapeDimensions(util::unpadDims(gather_out->getDimensions())); - shuffle_layer->setName(util::node_info(n).c_str()); - auto shuffle_out = shuffle_layer->getOutput(0); - - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_out); - - LOG_DEBUG("Output tensor shape: " << out->getDimensions()); + // IGatherLayer takes in input tensor, the indices, and the axis + // of input tensor to take indices from + auto gather_layer = ctx->net->addGather(*in, *const_out, axis); + TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n); + auto gather_out = gather_layer->getOutput(0); - return true; - }}) + // IShuffleLayer removes redundant dimensions + auto shuffle_layer = ctx->net->addShuffle(*gather_out); + TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n); + shuffle_layer->setReshapeDimensions(util::unpadDims(gather_out->getDimensions())); + shuffle_layer->setName(util::node_info(n).c_str()); + auto shuffle_out = shuffle_layer->getOutput(0); + + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_out); + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); + + return true; + }}) + .pattern( + {"aten::narrow.Tensor(Tensor(a) self, int dim, Tensor start, int length) -> Tensor(a)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto axis = args[1].unwrapToInt(); + torch::Tensor start = args[2].IValue()->toTensor().to(torch::kI32); + int32_t startIdx = start.item().to(); + auto length = (int32_t)args[3].unwrapToInt(); + + // index to access needs to be an at::Tensor + at::Tensor indices = torch::arange(startIdx, startIdx + length, 1).to(torch::kI32); + auto weights = Weights(ctx, indices); + + // IConstantLayer to convert indices from Weights to ITensor + auto const_layer = ctx->net->addConstant(weights.shape, weights.data); + TORCHTRT_CHECK(const_layer, "Unable to create constant layer from node: " << *n); + auto const_out = const_layer->getOutput(0); + + // IGatherLayer takes in input tensor, the indices, and the axis + // of input 
tensor to take indices from + auto gather_layer = ctx->net->addGather(*in, *const_out, axis); + TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n); + auto gather_out = gather_layer->getOutput(0); + + // IShuffleLayer removes redundant dimensions + auto shuffle_layer = ctx->net->addShuffle(*gather_out); + TORCHTRT_CHECK(shuffle_layer, "Unable to create shuffle layer from node: " << *n); + shuffle_layer->setReshapeDimensions(util::unpadDims(gather_out->getDimensions())); + shuffle_layer->setName(util::node_info(n).c_str()); + auto shuffle_out = shuffle_layer->getOutput(0); + + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], shuffle_out); + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); + + return true; + }}) .pattern( {"aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { @@ -236,29 +239,30 @@ auto select_registrations TORCHTRT_UNUSED = return true; }}) - .pattern({"aten::roll(Tensor self, int[1] shifts, int[1] dims=[]) -> (Tensor)", - [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { - auto in = args[0].ITensor(); - auto shifts = args[1].unwrapToIntList().vec(); - auto dims = args[2].unwrapToIntList().vec(); - - TORCHTRT_CHECK(dims.size() == shifts.size(), "dims.size() should be equal to shifts.size()"); - if (ctx->input_is_dynamic) { - TORCHTRT_THROW_ERROR("aten::roll is currently not support in dynamic input shape compilation"); - } else { - auto in_shape = util::toVec(in->getDimensions()); - for (size_t i = 0; i < dims.size(); i++) { - auto dim = dims[i] < 0 ? (in_shape.size() + dims[i]) : dims[i]; - TORCHTRT_CHECK(dim < in_shape.size(), "Dimension out of range"); - in = roll(ctx, in, shifts[i], dim, in_shape); - } - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in); - - LOG_DEBUG("Output tensor shape: " << out->getDimensions()); - - return true; - } - }}) + .pattern( + {"aten::roll(Tensor self, int[1] shifts, int[1] dims=[]) -> (Tensor)", + [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { + auto in = args[0].ITensor(); + auto shifts = args[1].unwrapToIntList().vec(); + auto dims = args[2].unwrapToIntList().vec(); + + TORCHTRT_CHECK(dims.size() == shifts.size(), "dims.size() should be equal to shifts.size()"); + if (ctx->input_is_dynamic) { + TORCHTRT_THROW_ERROR("aten::roll is currently not support in dynamic input shape compilation"); + } else { + auto in_shape = util::toVec(in->getDimensions()); + for (size_t i = 0; i < dims.size(); i++) { + auto dim = dims[i] < 0 ? (in_shape.size() + dims[i]) : dims[i]; + TORCHTRT_CHECK(dim < in_shape.size(), "Dimension out of range"); + in = roll(ctx, in, shifts[i], dim, in_shape); + } + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], in); + + LOG_DEBUG("Output tensor shape: " << out->getDimensions()); + + return true; + } + }}) .pattern( {"aten::index.Tensor(Tensor self, Tensor?[] indices) -> (Tensor)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { @@ -303,43 +307,98 @@ auto select_registrations TORCHTRT_UNUSED = {"aten::slice.Tensor(Tensor(a) self, int dim=0, int? start=None, int? 
end=None, int step=1) -> Tensor(a)", [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool { auto in = args[0].ITensorOrFreeze(ctx); - auto axis = args[1].unwrapToInt(); - auto maxDim = static_cast(in->getDimensions().d[axis]); - auto startIdx = 0; + int axis = args[1].unwrapToInt(); + int maxDim = static_cast(in->getDimensions().d[axis]); + bool dynamic_shape = ctx->input_is_dynamic; + auto input_dim = in->getDimensions(); + // add Shape Tensor + auto ishape_layer = ctx->net->addShape(*in); + auto ishape_tensor = ishape_layer->getOutput(0); // input shape + std::string node_name = n->outputs()[0]->debugName().c_str(); + + int startIdx = 0; auto startIdxIVal = args[2].IValue(); if (!startIdxIVal->isNone()) { - startIdx = startIdxIVal->toInt(); + startIdx = startIdxIVal->toInt() > std::numeric_limits::max() ? maxDim : startIdxIVal->toInt(); + startIdx = maxDim == -1 ? startIdx : std::min(startIdx, maxDim); } // Handle case when given tensor index is negative - auto start = (startIdx < 0) ? (maxDim + startIdx) : startIdx; + if (maxDim > 0) { // only for static shape + startIdx = (startIdx < 0) ? (maxDim + startIdx) : startIdx; + } + // Bound the end index to input tensor dimensions at specified axis - auto endIdx = maxDim; + int endIdx = maxDim; // -1 for dynamic shape auto endIdxIVal = args[3].IValue(); if (!endIdxIVal->isNone()) { - endIdx = std::min(endIdxIVal->toInt(), maxDim); + int truncate_value = endIdxIVal->toInt() > std::numeric_limits::max() ? maxDim : endIdxIVal->toInt(); + endIdx = maxDim == -1 ? truncate_value : std::min(truncate_value, maxDim); } - auto end = (endIdx < 0) ? (maxDim + endIdx) : endIdx; - auto step = args[4].unwrapToInt(); - - LOG_DEBUG("Start idx: " << start); - LOG_DEBUG("End idx: " << end); - - // indices to be accessed need to be an at::Tensor - at::Tensor indices = torch::arange(start, end, step).to(torch::kI32); - auto weights = Weights(ctx, indices); - - // IConstantLayer to convert indices from Weights to ITensor - auto const_layer = ctx->net->addConstant(weights.shape, weights.data); - TORCHTRT_CHECK(const_layer, "Unable to create constant layer from node: " << *n); - auto const_out = const_layer->getOutput(0); + if (maxDim > 0) { + endIdx = (endIdx < 0) ? (maxDim + endIdx) : endIdx; + } + int step = args[4].unwrapToInt(); + + // update start, end, stride for static shape + int nbdims = in->getDimensions().nbDims; + nvinfer1::Dims start_, size_, stride_; + start_.nbDims = nbdims; + size_.nbDims = nbdims; + stride_.nbDims = nbdims; + for (int i = 0; i < nbdims; i++) { + if (i == axis) { + start_.d[i] = startIdx; + size_.d[i] = (endIdx - startIdx - 1) / step + 1; + stride_.d[i] = step; + } else { + start_.d[i] = 0; + size_.d[i] = input_dim.d[i]; // for static + stride_.d[i] = 1; + } + } + auto slice_layer = ctx->net->addSlice(*in, start_, size_, stride_); + + if (dynamic_shape) { // dynamic shape + LOG_DEBUG("Using dynamic version of slice"); + // start tensor + at::Tensor start_tensor = torch::zeros({nbdims}).to(torch::kI32); + ; + start_tensor[axis] = startIdx; + auto start_itensor = tensor_to_const(ctx, start_tensor); + + // step tensor + at::Tensor stride_tensor = torch::ones({nbdims}).to(torch::kI32); + stride_tensor[axis] = step; + auto stride_itensor = tensor_to_const(ctx, stride_tensor); + + // end tensor + at::Tensor end_tensor = torch::zeros({nbdims}).to(torch::kI32); + for (int i = 0; i < nbdims; i++) { + if (i == axis) { + end_tensor[i] = endIdx == -1 ? -1 : endIdx - 1; + } else { + end_tensor[i] = input_dim.d[i] == -1 ? 
-1 : input_dim.d[i] - 1; + } + } + auto end_itensor = tensor_to_const(ctx, end_tensor); - // IGatherLayer takes in input tensor, the indices, and the axis of input tensor to take indices from - auto gather_layer = ctx->net->addGather(*in, *const_out, axis); - TORCHTRT_CHECK(gather_layer, "Unable to create gather layer from node: " << *n); - auto gather_out = gather_layer->getOutput(0); + // update start and end + nvinfer1::ITensor* out_start; + nvinfer1::ITensor* out_end; + auto start_end = normalize_start_and_end(ctx, ishape_tensor, start_itensor, end_itensor, nbdims, node_name); + out_start = start_end[0]; + out_end = start_end[1]; - auto out = ctx->AssociateValueAndTensor(n->outputs()[0], gather_out); + // calculate size + auto size_itensor = get_slice_size(ctx, out_start, out_end, stride_itensor, nbdims, node_name); + // update slice layer + slice_layer->setInput(1, *out_start); // start + slice_layer->setInput(2, *size_itensor); // size, must be set if input is dynamic + } + auto slice_out = slice_layer->getOutput(0); + + auto out = ctx->AssociateValueAndTensor(n->outputs()[0], slice_out); LOG_DEBUG("Slice layer output shape: " << out->getDimensions()); return true; diff --git a/core/conversion/var/Var.cpp b/core/conversion/var/Var.cpp index 027f8f1cb8..ff68590e3e 100644 --- a/core/conversion/var/Var.cpp +++ b/core/conversion/var/Var.cpp @@ -110,6 +110,7 @@ nvinfer1::ITensor* Var::ITensorOrFreeze(ConversionCtx* ctx) { out = ptr_.tensor; } + LOG_DEBUG("ITensor name: " << out->getName()); LOG_DEBUG("ITensor shape: " << out->getDimensions()); LOG_DEBUG("ITensor type: " << out->getType()); return out; diff --git a/tests/core/conversion/converters/test_select.cpp b/tests/core/conversion/converters/test_select.cpp index d0c9fe13bf..c4bb727d11 100644 --- a/tests/core/conversion/converters/test_select.cpp +++ b/tests/core/conversion/converters/test_select.cpp @@ -365,6 +365,208 @@ TEST(Converters, ATenSliceNegEndIndexConvertsCorrectly) { ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); } +TEST(Converters, ATenSliceDynamicBatchConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %dim : int = prim::Constant[value=0]() + %start : int = prim::Constant[value=1]() + %end : int = prim::Constant[value=15]() + %step : int = prim::Constant[value=2]() + %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) + return (%9))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + // dynamic shape in batch + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceDynamicBatchLargeEndConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %dim : int = prim::Constant[value=0]() + %start : int = prim::Constant[value=1]() + %end : int = prim::Constant[value=9223372036854775807]() + %step : int = prim::Constant[value=2]() + %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) + return (%9))IR"; + + auto g = std::make_shared(); + + 
torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + // dynamic shape in batch + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceDynamicNegStartBatchConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %dim : int = prim::Constant[value=0]() + %start : int = prim::Constant[value=-15]() + %end : int = prim::Constant[value=15]() + %step : int = prim::Constant[value=2]() + %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) + return (%9))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + // dynamic shape in batch + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceDynamicNegEndBatchConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %dim : int = prim::Constant[value=0]() + %start : int = prim::Constant[value=1]() + %end : int = prim::Constant[value=-2]() + %step : int = prim::Constant[value=3]() + %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) + return (%9))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + // dynamic shape in batch + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceDynamicNoneBatchConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %dim : int = prim::Constant[value=0]() + %start : None = prim::Constant() + %end : None = prim::Constant() + %step : int = prim::Constant[value=3]() + %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) + return (%9))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + // dynamic shape in batch + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); + auto trt = 
trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceDynamicConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %dim : int = prim::Constant[value=1]() + %start : int = prim::Constant[value=3]() + %end : int = prim::Constant[value=32]() + %step : int = prim::Constant[value=3]() + %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) + return (%9))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + // dynamic shape in dim 1, slice in dim 1 + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, false); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + +TEST(Converters, ATenSliceDynamic2ConvertsCorrectly) { + const auto graph = R"IR( + graph(%x.1 : Tensor): + %2 : None = prim::Constant() + %dim : int = prim::Constant[value=1]() + %start : int = prim::Constant[value=3]() + %end : int = prim::Constant[value=17]() + %step : int = prim::Constant[value=3]() + %9 : Tensor = aten::slice(%x.1, %dim, %start, %end, %step) + return (%9))IR"; + + auto g = std::make_shared(); + + torch::jit::parseIR(graph, g.get()); + + auto in = at::randint(1, 10, {16, 32}, {at::kCUDA}); + + auto jit_in = at::clone(in); + auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {}); + auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {jit_in}); + + auto trt_in = at::clone(in); + // dynamic shape in batch, slice in dim 1 + auto trt_results = torch_tensorrt::tests::util::RunGraphEngineDynamic(g, params, {trt_in}, true); + auto trt = trt_results[0].reshape(jit_results[0].sizes()); + + ASSERT_TRUE(torch_tensorrt::tests::util::almostEqual(jit_results[0], trt, 2e-6)); +} + TEST(Converters, ATenSplitSizesInScriptingConvertsCorrectly) { const auto graph = R"IR( graph(%x.1 : Tensor): diff --git a/tests/util/run_graph_engine.cpp b/tests/util/run_graph_engine.cpp index 04e0bd4811..c8c1a5a8a1 100644 --- a/tests/util/run_graph_engine.cpp +++ b/tests/util/run_graph_engine.cpp @@ -96,7 +96,7 @@ std::vector RunGraphEngineDynamic( bool dynamic_batch) { LOG_DEBUG("Running TRT version"); auto var_ins = get_var_inputs(g->inputs(), named_params); - auto in = core::ir::pair_input_vals_with_specs(var_ins, toInputs(inputs)); + auto in = core::ir::pair_input_vals_with_specs(var_ins, toInputsDynamic(inputs, dynamic_batch)); auto info = core::conversion::ConversionInfo(); info.inputs = std::move(in); info.engine_settings.workspace_size = (1 << 30);
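As a sanity check on the size arithmetic used by the dynamic path, here is the computation for the first dynamic test above (aten::slice with start=1, end=15, step=2 on an axis of size 16); plain C++, illustrative only:

// Illustrative only: the values the converter computes for the
// ATenSliceDynamicBatchConvertsCorrectly test (axis size 16, start 1, end 15, step 2).
#include <cassert>

int main() {
  const int dim = 16, start = 1, end = 15, step = 2;
  const int endIdx = end < dim ? end : dim;             // end bounded to the axis size
  const int end_inclusive = endIdx - 1;                 // end_tensor stores endIdx - 1
  const int size = (end_inclusive - start) / step + 1;  // get_slice_size formula
  assert(size == 7);                                    // rows 1, 3, 5, 7, 9, 11, 13
  return 0;
}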